Esempio n. 1
0
def main(argv):
    """Train the DELF model under a `tf.distribute.MirroredStrategy`.

    Builds the train/validation datasets, model, optimizer, summary writer and
    checkpoint manager from `FLAGS`, then runs the training loop until
    `max_iters` optimizer steps have been taken or the training data runs out.
    Validation and checkpointing happen periodically inside the loop.

    Args:
      argv: command-line arguments left over after flag parsing; only the
        program name is accepted.

    Raises:
      app.UsageError: if more than one positional argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Best effort: let TensorFlow grow GPU memory on demand for the first GPU.
    physical_devices = tf.config.list_physical_devices('GPU')
    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        print("Allowed set_memory_growth for physical GPU device 0.")
    except (IndexError, ValueError, RuntimeError):
        # IndexError: no GPU is visible. ValueError/RuntimeError: invalid
        # device or cannot modify virtual devices once initialized.
        print("Could not set_memory_growth for physical GPU device 0",
              file=sys.stderr)
    #-------------------------------------------------------------
    # Log flags used.
    logging.info('Running training script with\n')
    logging.info('logdir= %s', FLAGS.logdir)
    logging.info('initial_lr= %f', FLAGS.initial_lr)
    logging.info('block3_strides= %s', str(FLAGS.block3_strides))

    # ------------------------------------------------------------
    # Create the strategy.
    strategy = tf.distribute.MirroredStrategy()
    logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
    if FLAGS.debug:
        print('Number of devices:', strategy.num_replicas_in_sync)

    max_iters = FLAGS.max_iters
    global_batch_size = FLAGS.batch_size
    image_size = FLAGS.image_size
    num_eval_batches = int(50000 / global_batch_size)
    report_interval = 100
    eval_interval = 1000
    save_interval = 1000

    initial_lr = FLAGS.initial_lr

    # Global-norm threshold used when clipping gradients.
    clip_val = tf.constant(10.0)

    if FLAGS.debug:
        # Small, eager configuration for quick local runs.
        tf.config.run_functions_eagerly(True)
        global_batch_size = 4
        max_iters = 100
        num_eval_batches = 1
        save_interval = 1
        report_interval = 10

    # Determine the number of classes based on the version of the dataset.
    gld_info = gld.GoogleLandmarksInfo()
    num_classes = gld_info.num_classes[FLAGS.dataset_version]

    # ------------------------------------------------------------
    # Create the distributed train/validation sets.
    train_dataset = gld.CreateDataset(file_pattern=FLAGS.train_file_pattern,
                                      batch_size=global_batch_size,
                                      image_size=image_size,
                                      augmentation=FLAGS.use_augmentation,
                                      seed=FLAGS.seed)
    validation_dataset = gld.CreateDataset(
        file_pattern=FLAGS.validation_file_pattern,
        batch_size=global_batch_size,
        image_size=image_size,
        augmentation=False,
        seed=FLAGS.seed)

    train_dist_dataset = strategy.experimental_distribute_dataset(
        train_dataset)
    validation_dist_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)

    train_iter = iter(train_dist_dataset)
    validation_iter = iter(validation_dist_dataset)

    # Create a checkpoint directory to store the checkpoints.
    checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')

    # ------------------------------------------------------------
    # Finally, we do everything in distributed scope.
    with strategy.scope():
        # Compute loss.
        # Set reduction to `none` so we can do the reduction afterwards and divide
        # by global batch size.
        loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

        def compute_loss(labels, predictions):
            """Return the cross-entropy loss averaged over the global batch."""
            per_example_loss = loss_object(labels, predictions)
            return tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size)

        # Set up metrics.
        desc_validation_loss = tf.keras.metrics.Mean(
            name='desc_validation_loss')
        attn_validation_loss = tf.keras.metrics.Mean(
            name='attn_validation_loss')
        desc_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='desc_train_accuracy')
        attn_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='attn_train_accuracy')
        desc_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='desc_validation_accuracy')
        attn_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='attn_validation_accuracy')

        # ------------------------------------------------------------
        # Setup DELF model and optimizer.
        model = create_model(num_classes)
        logging.info('Model, datasets loaded.\nnum_classes= %d', num_classes)

        optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr,
                                            momentum=0.9)

        # Setup summary writer.
        summary_writer = tf.summary.create_file_writer(os.path.join(
            FLAGS.logdir, 'train_logs'),
                                                       flush_millis=10000)

        # Setup checkpoint directory.
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        manager = tf.train.CheckpointManager(checkpoint,
                                             checkpoint_prefix,
                                             max_to_keep=10,
                                             keep_checkpoint_every_n_hours=3)
        # Restores the checkpoint, if existing.
        checkpoint.restore(manager.latest_checkpoint)

        # ------------------------------------------------------------
        # Train step to run on one GPU.
        def train_step(inputs):
            """Train one batch.

            Args:
              inputs: tuple of (images, labels) for this replica.

            Returns:
              Tuple (desc_loss, attn_loss, reconstruction_loss) for this
              replica.
            """
            images, labels = inputs
            # Temporary workaround to avoid some corrupted labels. Valid class
            # indices are [0, num_classes - 1], so clip to that closed range.
            labels = tf.clip_by_value(labels, 0, model.num_classes - 1)

            def _backprop_loss(tape, loss, weights):
                """Backpropagate losses using clipped gradients.

                Args:
                  tape: gradient tape.
                  loss: scalar Tensor, loss value.
                  weights: keras model weights.
                """
                gradients = tape.gradient(loss, weights)
                clipped, _ = tf.clip_by_global_norm(gradients,
                                                    clip_norm=clip_val)
                optimizer.apply_gradients(zip(clipped, weights))

            # Record gradients and loss through backbone.
            with tf.GradientTape() as gradient_tape:
                # Make a forward pass to calculate prelogits.
                (desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
                 dim_expanded_features,
                 _) = model.global_and_local_forward_pass(images)

                # Calculate global loss by applying the descriptor classifier.
                if FLAGS.delg_global_features:
                    desc_logits = model.desc_classification(
                        desc_prelogits, labels)
                else:
                    desc_logits = model.desc_classification(desc_prelogits)
                desc_loss = compute_loss(labels, desc_logits)

                # Calculate attention loss by applying the attention block classifier.
                attn_logits = model.attn_classification(attn_prelogits)
                attn_loss = compute_loss(labels, attn_logits)

                # Calculate reconstruction loss between the attention prelogits and the
                # backbone.
                if FLAGS.use_autoencoder:
                    block3 = tf.stop_gradient(backbone_blocks['block3'])
                    reconstruction_loss = tf.math.reduce_mean(
                        tf.keras.losses.MSE(block3, dim_expanded_features))
                else:
                    reconstruction_loss = 0

                # Cumulate global loss, attention loss and reconstruction loss.
                total_loss = (
                    desc_loss + FLAGS.attention_loss_weight * attn_loss +
                    FLAGS.reconstruction_loss_weight * reconstruction_loss)

            # Perform backpropagation through the descriptor and attention layers
            # together. Note that this will increment the number of iterations of
            # "optimizer".
            _backprop_loss(gradient_tape, total_loss, model.trainable_weights)

            # Step number, for summary purposes.
            global_step = optimizer.iterations

            # Input image-related summaries.
            tf.summary.image('batch_images', (images + 1.0) / 2.0,
                             step=global_step)
            tf.summary.scalar('image_range/max',
                              tf.reduce_max(images),
                              step=global_step)
            tf.summary.scalar('image_range/min',
                              tf.reduce_min(images),
                              step=global_step)

            # Attention and sparsity summaries.
            _attention_summaries(attn_scores, global_step)
            activations_zero_fractions = {
                'sparsity/%s' % k: tf.nn.zero_fraction(v)
                for k, v in backbone_blocks.items()
            }
            for k, v in activations_zero_fractions.items():
                tf.summary.scalar(k, v, step=global_step)

            # Scaling factor summary for cosine logits for a DELG model.
            if FLAGS.delg_global_features:
                tf.summary.scalar('desc/scale_factor',
                                  model.scale_factor,
                                  step=global_step)

            # Record train accuracies.
            _record_accuracy(desc_train_accuracy, desc_logits, labels)
            _record_accuracy(attn_train_accuracy, attn_logits, labels)

            return desc_loss, attn_loss, reconstruction_loss

        # ------------------------------------------------------------
        def validation_step(inputs):
            """Validate one batch.

            Updates the validation loss/accuracy metrics and returns the
            current (desc, attn) validation accuracy metric results.
            """
            images, labels = inputs
            # Same closed-range clipping as in the train step.
            labels = tf.clip_by_value(labels, 0, model.num_classes - 1)

            # Get descriptor predictions.
            blocks = {}
            prelogits = model.backbone(images,
                                       intermediates_dict=blocks,
                                       training=False)
            if FLAGS.delg_global_features:
                logits = model.desc_classification(prelogits,
                                                   labels,
                                                   training=False)
            else:
                logits = model.desc_classification(prelogits, training=False)
            softmax_probabilities = tf.keras.layers.Softmax()(logits)

            validation_loss = loss_object(labels, logits)
            desc_validation_loss.update_state(validation_loss)
            desc_validation_accuracy.update_state(labels,
                                                  softmax_probabilities)

            # Get attention predictions.
            block3 = blocks['block3']  # pytype: disable=key-error
            prelogits, _, _ = model.attention(block3, training=False)

            logits = model.attn_classification(prelogits, training=False)
            softmax_probabilities = tf.keras.layers.Softmax()(logits)

            validation_loss = loss_object(labels, logits)
            attn_validation_loss.update_state(validation_loss)
            attn_validation_accuracy.update_state(labels,
                                                  softmax_probabilities)

            return desc_validation_accuracy.result(
            ), attn_validation_accuracy.result()

        # `run` replicates the provided computation and runs it
        # with the distributed input.
        @tf.function
        def distributed_train_step(dataset_inputs):
            """Run one train step on every replica and sum the losses."""
            # Each value is the per-replica loss returned by `train_step`
            # (descriptor, attention, reconstruction).
            desc_per_replica_loss, attn_per_replica_loss, recon_per_replica_loss = (
                strategy.run(train_step, args=(dataset_inputs, )))

            # Reduce over the replicas.
            desc_global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                               desc_per_replica_loss,
                                               axis=None)
            attn_global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                               attn_per_replica_loss,
                                               axis=None)
            recon_global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                                recon_per_replica_loss,
                                                axis=None)

            return desc_global_loss, attn_global_loss, recon_global_loss

        @tf.function
        def distributed_validation_step(dataset_inputs):
            """Run one validation step on every replica."""
            return strategy.run(validation_step, args=(dataset_inputs, ))

        # ------------------------------------------------------------
        # *** TRAIN LOOP ***
        with summary_writer.as_default():
            # Only record summaries every `report_interval` optimizer steps.
            record_cond = lambda: tf.equal(
                optimizer.iterations % report_interval, 0)
            with tf.summary.record_if(record_cond):
                global_step_value = optimizer.iterations.numpy()

                # TODO(dananghel): try to load pretrained weights at backbone creation.
                # Load pretrained weights for ResNet50 trained on ImageNet.
                # Only done on a fresh run (step 0), i.e. not after resuming
                # from a checkpoint that already advanced the optimizer.
                if (FLAGS.imagenet_checkpoint
                        is not None) and (not global_step_value):
                    logging.info(
                        'Attempting to load ImageNet pretrained weights.')
                    # One train step is run before restoring; presumably this
                    # creates the model variables first -- TODO confirm.
                    input_batch = next(train_iter)
                    _, _, _ = distributed_train_step(input_batch)
                    model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
                    logging.info('Done.')
                else:
                    logging.info('Skip loading ImageNet pretrained weights.')
                if FLAGS.debug:
                    model.backbone.log_weights()

                last_summary_step_value = None
                last_summary_time = None
                while global_step_value < max_iters:
                    # input_batch : images(b, h, w, c), labels(b,).
                    try:
                        input_batch = next(train_iter)
                    except (StopIteration, tf.errors.OutOfRangeError):
                        # Break if we run out of data in the dataset. `next()`
                        # on an exhausted iterator raises StopIteration, so
                        # both signals are handled.
                        logging.info(
                            'Stopping training at global step %d, no more data',
                            global_step_value)
                        break

                    # Set learning rate and run the training step over num_gpu gpus.
                    optimizer.learning_rate = _learning_rate_schedule(
                        optimizer.iterations.numpy(), max_iters, initial_lr)
                    desc_dist_loss, attn_dist_loss, recon_dist_loss = (
                        distributed_train_step(input_batch))

                    # Step number, to be used for summary/logging.
                    global_step = optimizer.iterations
                    global_step_value = global_step.numpy()

                    # LR, losses and accuracies summaries.
                    tf.summary.scalar('learning_rate',
                                      optimizer.learning_rate,
                                      step=global_step)
                    tf.summary.scalar('loss/desc/crossentropy',
                                      desc_dist_loss,
                                      step=global_step)
                    tf.summary.scalar('loss/attn/crossentropy',
                                      attn_dist_loss,
                                      step=global_step)
                    if FLAGS.use_autoencoder:
                        tf.summary.scalar('loss/recon/mse',
                                          recon_dist_loss,
                                          step=global_step)

                    tf.summary.scalar('train_accuracy/desc',
                                      desc_train_accuracy.result(),
                                      step=global_step)
                    tf.summary.scalar('train_accuracy/attn',
                                      attn_train_accuracy.result(),
                                      step=global_step)

                    # Summary for number of global steps taken per second.
                    current_time = time.time()
                    if (last_summary_step_value is not None
                            and last_summary_time is not None):
                        tf.summary.scalar(
                            'global_steps_per_sec',
                            (global_step_value - last_summary_step_value) /
                            (current_time - last_summary_time),
                            step=global_step)
                    # The reference point is refreshed on every iteration, so
                    # the rate above is measured against the previous step.
                    last_summary_step_value = global_step_value
                    last_summary_time = current_time

                    # Print to console if running locally.
                    if FLAGS.debug:
                        if global_step_value % report_interval == 0:
                            print(global_step.numpy())
                            print('desc:', desc_dist_loss.numpy())
                            print('attn:', attn_dist_loss.numpy())

                    # Validate once in {eval_interval*n, n \in N} steps.
                    if global_step_value % eval_interval == 0:
                        # Reset so that an exhausted validation iterator can
                        # neither raise NameError nor reuse stale results.
                        desc_validation_result = None
                        attn_validation_result = None
                        for i in range(num_eval_batches):
                            try:
                                validation_batch = next(validation_iter)
                                desc_validation_result, attn_validation_result = (
                                    distributed_validation_step(
                                        validation_batch))
                            except (StopIteration,
                                    tf.errors.OutOfRangeError):
                                logging.info(
                                    'Stopping eval at batch %d, no more data',
                                    i)
                                break

                        if (desc_validation_result is None
                                or attn_validation_result is None):
                            logging.warning(
                                'No validation batch was processed at global '
                                'step %d; skipping validation summaries.',
                                global_step_value)
                        else:
                            # Log validation results to tensorboard.
                            tf.summary.scalar('validation/desc',
                                              desc_validation_result,
                                              step=global_step)
                            tf.summary.scalar('validation/attn',
                                              attn_validation_result,
                                              step=global_step)

                            logging.info('\nValidation(%f)\n',
                                         global_step_value)
                            logging.info(': desc: %f\n',
                                         desc_validation_result.numpy())
                            logging.info(': attn: %f\n',
                                         attn_validation_result.numpy())
                            # Print to console.
                            if FLAGS.debug:
                                print('Validation: desc:',
                                      desc_validation_result.numpy())
                                print('          : attn:',
                                      attn_validation_result.numpy())

                    # Save checkpoint once (each save_interval*n, n \in N) steps, or if
                    # this is the last iteration.
                    # TODO(andrearaujo): save only in one of the two ways. They are
                    # identical, the only difference is that the manager adds some extra
                    # prefixes and variables (eg, optimizer variables).
                    if (global_step_value % save_interval
                            == 0) or (global_step_value >= max_iters):
                        save_path = manager.save(
                            checkpoint_number=global_step_value)
                        logging.info('Saved (%d) at %s', global_step_value,
                                     save_path)

                        file_path = '%s/delf_weights' % FLAGS.logdir
                        model.save_weights(file_path, save_format='tf')
                        logging.info('Saved weights (%d) at %s',
                                     global_step_value, file_path)

                    # Reset metrics for next step.
                    desc_train_accuracy.reset_states()
                    attn_train_accuracy.reset_states()
                    desc_validation_loss.reset_states()
                    attn_validation_loss.reset_states()
                    desc_validation_accuracy.reset_states()
                    attn_validation_accuracy.reset_states()

        logging.info('Finished training for %d steps.', max_iters)
Esempio n. 2
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    #-------------------------------------------------------------
    # Log flags used.
    logging.info('Running training script with\n')
    logging.info('logdir= %s', FLAGS.logdir)
    logging.info('initial_lr= %f', FLAGS.initial_lr)
    logging.info('block3_strides= %s', str(FLAGS.block3_strides))

    # ------------------------------------------------------------
    # Create the strategy.
    strategy = tf.distribute.MirroredStrategy()
    logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
    if FLAGS.debug:
        print('Number of devices:', strategy.num_replicas_in_sync)

    max_iters = FLAGS.max_iters
    global_batch_size = FLAGS.batch_size
    image_size = 321
    num_eval_batches = int(50000 / global_batch_size)
    report_interval = 100
    eval_interval = 1000
    save_interval = 20000

    initial_lr = FLAGS.initial_lr

    clip_val = tf.constant(10.0)

    if FLAGS.debug:
        #tf.config.run_functions_eagerly(True)
        global_batch_size = 32  #4
        max_iters = 100
        num_eval_batches = 4  #1
        save_interval = 20
        report_interval = 1

    # Determine the number of classes based on the version of the dataset.
    gld_info = gld.GoogleLandmarksInfo()
    num_classes = gld_info.num_classes[FLAGS.dataset_version]

    # ------------------------------------------------------------
    # Create the distributed train/validation sets.
    train_dataset = gld.CreateDataset(file_pattern=FLAGS.train_file_pattern,
                                      batch_size=global_batch_size,
                                      image_size=image_size,
                                      augmentation=FLAGS.use_augmentation,
                                      seed=FLAGS.seed)
    validation_dataset = gld.CreateDataset(
        file_pattern=FLAGS.validation_file_pattern,
        batch_size=global_batch_size,
        image_size=image_size,
        augmentation=False,
        seed=FLAGS.seed)

    train_dist_dataset = strategy.experimental_distribute_dataset(
        train_dataset)
    validation_dist_dataset = strategy.experimental_distribute_dataset(
        validation_dataset)

    train_iter = iter(train_dist_dataset)
    validation_iter = iter(validation_dist_dataset)

    # Create a checkpoint directory to store the checkpoints.
    checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')

    # ------------------------------------------------------------
    # Finally, we do everything in distributed scope.
    with strategy.scope():
        # Compute loss.
        # Set reduction to `none` so we can do the reduction afterwards and divide
        # by global batch size.
        loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

        def compute_loss(labels, predictions):
            per_example_loss = loss_object(labels, predictions)
            return tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size)

        # Set up metrics.
        desc_validation_loss = tf.keras.metrics.Mean(
            name='desc_validation_loss')
        attn_validation_loss = tf.keras.metrics.Mean(
            name='attn_validation_loss')
        desc_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='desc_train_accuracy')
        attn_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='attn_train_accuracy')
        desc_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='desc_validation_accuracy')
        attn_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='attn_validation_accuracy')

        # ------------------------------------------------------------
        # Setup DELF model and optimizer.
        model = create_model(num_classes)
        logging.info('Model, datasets loaded.\nnum_classes= %d', num_classes)

        optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr,
                                            momentum=0.9)

        # Setup summary writer.
        summary_writer = tf.summary.create_file_writer(os.path.join(
            FLAGS.logdir, 'train_logs'),
                                                       flush_millis=10000)

        # Setup checkpoint directory.
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        manager = tf.train.CheckpointManager(checkpoint,
                                             "gs://delgckpts/",
                                             max_to_keep=1)

        # ------------------------------------------------------------
        # Train step to run on one GPU.
        def train_step(inputs):
            """Train one batch."""
            images, labels = inputs
            # Temporary workaround to avoid some corrupted labels.
            labels = tf.clip_by_value(labels, 0, model.num_classes - 1)

            global_step = optimizer.iterations
            tf.summary.image('batch_images', (images + 1.0) / 2.0,
                             step=global_step)
            tf.summary.scalar('image_range/max',
                              tf.reduce_max(images),
                              step=global_step)
            tf.summary.scalar('image_range/min',
                              tf.reduce_min(images),
                              step=global_step)

            # TODO(andrearaujo): we should try to unify the backprop into a single
            # function, instead of applying once to descriptor then to attention.
            def _backprop_loss(tape, loss, weights):
                """Backpropogate losses using clipped gradients.

        Args:
          tape: gradient tape.
          loss: scalar Tensor, loss value.
          weights: keras model weights.
        """
                gradients = tape.gradient(loss, weights)
                clipped, _ = tf.clip_by_global_norm(gradients,
                                                    clip_norm=clip_val)
                optimizer.apply_gradients(zip(clipped, weights))

            # Record gradients and loss through backbone.
            with tf.GradientTape() as desc_tape:

                blocks = {}
                prelogits = model.backbone(images,
                                           intermediates_dict=blocks,
                                           training=True)

                # Report sparsity.
                activations_zero_fractions = {
                    'sparsity/%s' % k: tf.nn.zero_fraction(v)
                    for k, v in blocks.items()
                }
                for k, v in activations_zero_fractions.items():
                    tf.summary.scalar(k, v, step=global_step)

                # Apply descriptor classifier and report scale factor.
                if FLAGS.delg_global_features:
                    logits = model.desc_classification(prelogits, labels)
                    tf.summary.scalar('desc/scale_factor',
                                      model.scale_factor,
                                      step=global_step)
                else:
                    logits = model.desc_classification(prelogits)

                desc_loss = compute_loss(labels, logits)

            # Backprop only through backbone weights.
            _backprop_loss(desc_tape, desc_loss, model.desc_trainable_weights)

            # Record descriptor train accuracy.
            _record_accuracy(desc_train_accuracy, logits, labels)

            # Record gradients and loss through attention block.
            with tf.GradientTape() as attn_tape:
                block3 = blocks['block3']  # pytype: disable=key-error

                # Stopping gradients according to DELG paper:
                # (https://arxiv.org/abs/2001.05027).
                block3 = tf.stop_gradient(block3)

                prelogits, scores, _ = model.attention(block3, training=True)
                _attention_summaries(scores, global_step)

                # Apply attention block classifier.
                logits = model.attn_classification(prelogits)

                attn_loss = compute_loss(labels, logits)

            # Backprop only through attention weights.
            _backprop_loss(attn_tape, attn_loss, model.attn_trainable_weights)

            # Record attention train accuracy.
            _record_accuracy(attn_train_accuracy, logits, labels)

            return desc_loss, attn_loss

        # ------------------------------------------------------------
        def validation_step(inputs):
            """Validate one batch."""
            images, labels = inputs
            labels = tf.clip_by_value(labels, 0, model.num_classes - 1)

            # Get descriptor predictions.
            blocks = {}
            prelogits = model.backbone(images,
                                       intermediates_dict=blocks,
                                       training=False)
            if FLAGS.delg_global_features:
                logits = model.desc_classification(prelogits,
                                                   labels,
                                                   training=False)
            else:
                logits = model.desc_classification(prelogits, training=False)
            softmax_probabilities = tf.keras.layers.Softmax()(logits)

            validation_loss = loss_object(labels, logits)
            desc_validation_loss.update_state(validation_loss)
            desc_validation_accuracy.update_state(labels,
                                                  softmax_probabilities)

            # Get attention predictions.
            block3 = blocks['block3']  # pytype: disable=key-error
            prelogits, _, _ = model.attention(block3, training=False)

            logits = model.attn_classification(prelogits, training=False)
            softmax_probabilities = tf.keras.layers.Softmax()(logits)

            validation_loss = loss_object(labels, logits)
            attn_validation_loss.update_state(validation_loss)
            attn_validation_accuracy.update_state(labels,
                                                  softmax_probabilities)

            return desc_validation_accuracy.result(
            ), attn_validation_accuracy.result()

        # `run` replicates the provided computation and runs it
        # with the distributed input.
        @tf.function
        def distributed_train_step(dataset_inputs):
            """Get the actual losses."""
            # Each (desc, attn) is a list of 3 losses - crossentropy, reg, total.
            desc_per_replica_loss, attn_per_replica_loss = (strategy.run(
                train_step, args=(dataset_inputs, )))

            # Reduce over the replicas.
            desc_global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                               desc_per_replica_loss,
                                               axis=None)
            attn_global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                               attn_per_replica_loss,
                                               axis=None)

            return desc_global_loss, attn_global_loss

        @tf.function
        def distributed_validation_step(dataset_inputs):
            """Run `validation_step` on every replica with the given batch.

            Args:
              dataset_inputs: The batch forwarded unchanged to
                `validation_step` through `strategy.run`.

            Returns:
              Whatever `strategy.run` yields for `validation_step`; no
              cross-replica reduction is applied here.
            """
            per_replica_result = strategy.run(validation_step,
                                              args=(dataset_inputs, ))
            return per_replica_result

        # ------------------------------------------------------------
        # *** TRAIN LOOP ***
        with summary_writer.as_default():
            # NOTE(review): the predicate passed to `record_if` is an
            # already-built expression rather than a callable, so under eager
            # execution it looks like it is computed once from the value of
            # `optimizer.iterations` at this point instead of being re-checked
            # on every step -- confirm this is the intended gating.
            with tf.summary.record_if(
                    tf.math.equal(0, optimizer.iterations % report_interval)):

                # TODO(dananghel): try to load pretrained weights at backbone creation.
                # Load pretrained weights for ResNet50 trained on ImageNet.
                if FLAGS.imagenet_checkpoint is not None:
                    logging.info(
                        'Attempting to load ImageNet pretrained weights.')
                    # One training step is executed before restoring,
                    # presumably so the model variables exist by the time
                    # `restore_weights` runs -- TODO confirm. Note that this
                    # step consumes one training batch.
                    input_batch = next(train_iter)
                    _, _ = distributed_train_step(input_batch)
                    model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
                    logging.info('Done.')
                else:
                    logging.info('Skip loading ImageNet pretrained weights.')
                #if FLAGS.debug:
                #  model.backbone.log_weights()

                global_step_value = optimizer.iterations.numpy()
                while global_step_value < max_iters:

                    # input_batch : images(b, h, w, c), labels(b,).
                    try:
                        input_batch = next(train_iter)
                    except (StopIteration, tf.errors.OutOfRangeError):
                        # Break if we run out of data in the dataset. Python
                        # iterators report exhaustion with StopIteration, so
                        # both signals are treated as "no more data".
                        logging.info(
                            'Stopping training at global step %d, no more data',
                            global_step_value)
                        break

                    # Set learning rate for optimizer to use.
                    # `global_step` is the optimizer's own iteration counter
                    # (not a copy), while `global_step_value` is a snapshot
                    # taken before this step's update is applied.
                    global_step = optimizer.iterations
                    global_step_value = global_step.numpy()

                    learning_rate = _learning_rate_schedule(
                        global_step_value, max_iters, initial_lr)
                    optimizer.learning_rate = learning_rate
                    tf.summary.scalar('learning_rate',
                                      optimizer.learning_rate,
                                      step=global_step)

                    # Run the training step over num_gpu gpus.
                    desc_dist_loss, attn_dist_loss = distributed_train_step(
                        input_batch)

                    # Log losses and accuracies to tensorboard.
                    tf.summary.scalar('loss/desc/crossentropy',
                                      desc_dist_loss,
                                      step=global_step)
                    tf.summary.scalar('loss/attn/crossentropy',
                                      attn_dist_loss,
                                      step=global_step)
                    tf.summary.scalar('train_accuracy/desc',
                                      desc_train_accuracy.result(),
                                      step=global_step)
                    tf.summary.scalar('train_accuracy/attn',
                                      attn_train_accuracy.result(),
                                      step=global_step)

                    # Print to console if running locally.
                    if FLAGS.debug:
                        if global_step_value % report_interval == 0:
                            print(global_step.numpy())
                            print('desc:', desc_dist_loss.numpy())
                            print('attn:', attn_dist_loss.numpy())

                    # Validate once in {eval_interval*n, n \in N} steps.
                    if global_step_value % eval_interval == 0:
                        # Start from "no result" so that an exhausted
                        # validation iterator (or num_eval_batches == 0)
                        # can neither leave these names unbound nor make
                        # us re-log results from an earlier evaluation.
                        desc_validation_result = None
                        attn_validation_result = None
                        for i in range(num_eval_batches):
                            try:
                                validation_batch = next(validation_iter)
                                desc_validation_result, attn_validation_result = (
                                    distributed_validation_step(
                                        validation_batch))
                            except (StopIteration,
                                    tf.errors.OutOfRangeError):
                                logging.info(
                                    'Stopping eval at batch %d, no more data',
                                    i)
                                break

                        if desc_validation_result is None:
                            # Not a single validation batch was evaluated.
                            logging.info(
                                'Skipping validation report at step %d, '
                                'no validation results', global_step_value)
                        else:
                            # Log validation results to tensorboard. The
                            # values come from the last batch that was
                            # evaluated successfully in the loop above.
                            tf.summary.scalar('validation/desc',
                                              desc_validation_result,
                                              step=global_step)
                            tf.summary.scalar('validation/attn',
                                              attn_validation_result,
                                              step=global_step)

                            logging.info('\nValidation(%f)\n',
                                         global_step_value)
                            logging.info(': desc: %f\n',
                                         desc_validation_result.numpy())
                            logging.info(': attn: %f\n',
                                         attn_validation_result.numpy())
                            # Print to console.
                            if FLAGS.debug:
                                print('Validation: desc:',
                                      desc_validation_result.numpy())
                                print('          : attn:',
                                      attn_validation_result.numpy())

                    # Save checkpoint once (each save_interval*n, n \in N) steps.
                    # TODO(andrearaujo): save only in one of the two ways. They are
                    # identical, the only difference is that the manager adds some extra
                    # prefixes and variables (eg, optimizer variables).
                    if global_step_value % save_interval == 0:
                        save_path = manager.save()
                        logging.info('Saved (%d) at %s', global_step_value,
                                     save_path)

                        file_path = '%s/delf_weights' % FLAGS.logdir
                        model.save_weights(file_path, save_format='tf')
                        logging.info('Saved weights (%d) at %s',
                                     global_step_value, file_path)

                    # Reset metrics for next step.
                    desc_train_accuracy.reset_states()
                    attn_train_accuracy.reset_states()
                    desc_validation_loss.reset_states()
                    attn_validation_loss.reset_states()
                    desc_validation_accuracy.reset_states()
                    attn_validation_accuracy.reset_states()

                    # Re-read the live counter: the `while` condition only
                    # sees the snapshot taken before this step ran.
                    if global_step.numpy() > max_iters:
                        break

        # NOTE(review): this always reports `max_iters`, even when the
        # training loop stopped early because the training data ran out.
        logging.info('Finished training for %d steps.', max_iters)
    # Export the trained model in SavedModel format.
    # NOTE(review): `file_path` is built from FLAGS.logdir but is not used by
    # the export below, which writes to a hard-coded `gs://` location instead
    # -- confirm this is intentional before running with a different logdir.
    file_path = '%s/delf_saved_model' % FLAGS.logdir
    #model.save(file_path)
    tf.saved_model.save(model, "gs://delgckpts/delf_saved_model")