Example #1
def launch_tensorboard(log_dir):
  """Runs tensorboard with the given `log_dir`.

  Args:
    log_dir: String, directory to launch tensorboard in.
  """
  tensorboard = program.TensorBoard()
  tensorboard.configure(argv=[None, '--logdir', log_dir])
  url = tensorboard.launch()
  global_features_utils.debug_and_log("Launching Tensorboard: {}".format(url))
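A brief usage sketch of the same TensorBoard launching pattern; it assumes `program` in the snippet above is `tensorboard.program`, and the log directory below is a hypothetical placeholder.

# Usage sketch (assumptions: `program` comes from the `tensorboard` package;
# '/tmp/tb_logs' is a hypothetical log directory).
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', '/tmp/tb_logs'])
url = tb.launch()  # TensorBoard keeps running in a background thread.
print('Launching Tensorboard: {}'.format(url))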
Example #2
def extract_global_descriptors_from_list(net,
                                         images,
                                         image_size,
                                         bounding_boxes=None,
                                         scales=[1.],
                                         multi_scale_power=1.,
                                         print_freq=10):
    """Extracting global descriptors from a list of images.

  Args:
    net: Model object, network for the forward pass.
    images: Absolute image paths as strings.
    image_size: Integer, defines the maximum size of longer image side.
    bounding_boxes: List of (x1,y1,x2,y2) tuples to crop the query images.
    scales: List of float scales.
    multi_scale_power: Float, multi-scale normalization power parameter.
    print_freq: Printing frequency for debugging.

  Returns:
    descriptors: Global descriptors for the input images.
  """
    # Creating dataset loader.
    data = generic_dataset.ImagesFromList(root='',
                                          image_paths=images,
                                          imsize=image_size,
                                          bounding_boxes=bounding_boxes)

    def _data_gen():
        return (inst for inst in data)

    loader = tf.data.Dataset.from_generator(_data_gen,
                                            output_types=(tf.float32))
    loader = loader.batch(1)

    # Extracting vectors.
    descriptors = tf.zeros((0, net.meta['outputdim']))
    for i, input in enumerate(loader):
        if len(scales) == 1 and scales[0] == 1:
            descriptors = tf.concat([descriptors, net(input)], 0)
        else:
            descriptors = tf.concat([
                descriptors,
                extract_multi_scale_descriptor(net, input, scales,
                                               multi_scale_power)
            ], 0)

        if (i + 1) % print_freq == 0 or (i + 1) == len(images):
            global_features_utils.debug_and_log('\r>>>> {}/{} done...'.format(
                (i + 1), len(images)),
                                                debug_on_the_same_line=True)
    global_features_utils.debug_and_log('', log=False)

    descriptors = tf.transpose(descriptors, perm=[1, 0])
    return descriptors
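A hedged usage sketch for the extraction function above; the model construction and the image paths are hypothetical, and it assumes a network such as the `GlobalFeatureNet` shown in Example #3.

# Usage sketch (hypothetical paths and model; not part of the original code).
net = GlobalFeatureNet(architecture='ResNet101', pooling='gem')
image_paths = ['/path/to/img_0.jpg', '/path/to/img_1.jpg']  # Placeholders.
descriptors = extract_global_descriptors_from_list(net,
                                                   images=image_paths,
                                                   image_size=1024,
                                                   scales=[1.])
# Shape is (outputdim, num_images) because of the final transpose above.
print(descriptors.shape)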
Example #3
    def __init__(self,
                 architecture='ResNet101',
                 pooling='gem',
                 whitening=False,
                 pretrained=True,
                 data_root=''):
        """GlobalFeatureNet network initialization.

    Args:
      architecture: Network backbone.
      pooling: Pooling method used 'mac'/'spoc'/'gem'.
      whitening: Bool, whether to use whitening.
      pretrained: Bool, whether to initialize the network with the weights
        pretrained on ImageNet.
      data_root: String, path to the data folder where the precomputed
        whitening is/will be saved in case `whitening` is True.

    Raises:
      ValueError: If `architecture` is not supported.
    """
        if architecture not in _OUTPUT_DIM.keys():
            raise ValueError(
                "Architecture {} is not supported.".format(architecture))

        super(GlobalFeatureNet, self).__init__()

        # Get standard output dimensionality size.
        dim = _OUTPUT_DIM[architecture]

        if pretrained:
            # Initialize with network pretrained on imagenet.
            net_in = getattr(tf.keras.applications,
                             architecture)(include_top=False,
                                           weights="imagenet")
        else:
            # Initialize with random weights.
            net_in = getattr(tf.keras.applications,
                             architecture)(include_top=False, weights=None)

        # Initialize `feature_extractor`. Take only convolutions for
        # `feature_extractor`, always end with ReLU to make last activations
        # non-negative.
        if architecture.lower().startswith('densenet'):
            tmp_model = tf.keras.Sequential()
            tmp_model.add(net_in)
            net_in = tmp_model
            net_in.add(tf.keras.layers.ReLU())

        # Initialize pooling.
        self.pool = _POOLING[pooling]()

        # Initialize whitening.
        if whitening:
            if pretrained and architecture in _WHITENING_CONFIG:
                # If precomputed whitening for the architecture exists,
                # the fully-connected layer is going to be initialized according to
                # the precomputed layer configuration.
                global_features_utils.debug_and_log(
                    ">> {}: for '{}' custom computed whitening '{}' is used.".
                    format(os.getcwd(), architecture,
                           os.path.basename(_WHITENING_CONFIG[architecture])))
                # The layer configuration is downloaded to the `data_root` folder.
                whiten_dir = os.path.join(data_root, architecture)
                path = tf.keras.utils.get_file(
                    fname=whiten_dir, origin=_WHITENING_CONFIG[architecture])
                # Whitening configuration is loaded.
                with tf.io.gfile.GFile(path, 'rb') as learned_whitening_file:
                    whitening_config = pickle.load(learned_whitening_file)
                # Whitening layer is initialized according to the configuration.
                self.whiten = tf.keras.layers.Dense.from_config(
                    whitening_config)
            else:
                # In case no precomputed whitening exists for the chosen
                # architecture, the fully-connected whitening layer is
                # initialized with random weights.
                self.whiten = tf.keras.layers.Dense(dim,
                                                    activation=None,
                                                    use_bias=True)
                global_features_utils.debug_and_log(
                    ">> There is either no whitening computed for the "
                    "used network architecture or pretrained is False,"
                    " random weights are used.")
        else:
            self.whiten = None

        # Create meta information to be stored in the network.
        self.meta = {
            'architecture': architecture,
            'pooling': pooling,
            'whitening': whitening,
            'outputdim': dim
        }

        self.feature_extractor = net_in
        self.normalize = normalization.L2Normalization()
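The class's `call` method is not included in this excerpt. The sketch below only illustrates how the components initialized above (feature extractor, pooling, optional whitening, L2 normalization) are typically composed into a global descriptor, and should be read as an assumption rather than the repository's implementation.

# Illustrative forward-pass composition (assumption; written as a standalone
# helper so it does not pretend to be the class's real `call` method).
def global_descriptor_sketch(model, images, training=False):
  features = model.feature_extractor(images, training=training)  # Conv maps.
  pooled = model.pool(features)             # 'mac' / 'spoc' / 'gem' pooling.
  if model.whiten is not None:
    pooled = model.whiten(pooled)           # Learned fully-connected whitening.
  return model.normalize(pooled)            # L2-normalized global descriptor.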
    def create_epoch_tuples(self, net):
        """Creates epoch tuples with the hard-negative re-mining.

    Negative examples are selected from clusters different than the cluster
    of the query image, as the clusters are ideally non-overlaping. For
    every query image we choose  hard-negatives, that is, non-matching images
    with the most similar descriptor. Hard-negatives depend on the current
    CNN parameters. K-nearest neighbors from all non-matching images are
    selected. Query images are selected randomly. Positives examples are
    fixed for the related query image during the whole training process.

    Args:
      net: Model, network to be used for negative re-mining.

    Raises:
      ValueError: If the pool_size is smaller than the number of negative
        images per tuple.

    Returns:
      avg_l2: Float, average negative L2-distance.
    """
        self._n = 0

        if self._pool_size < self._num_negatives:
            raise ValueError(
                "Unable to create epoch tuples. Negative pool_size "
                "should be larger than the number of negative images "
                "per tuple.")

        global_features_utils.debug_and_log(
            '>> Creating tuples for an epoch of {}-{}...'.format(
                self._name, self._mode), True)
        global_features_utils.debug_and_log(">> Used network: ", True)
        global_features_utils.debug_and_log(net.meta_repr(), True)

        ## Selecting queries.
        # Draw `num_queries` random queries for the tuples.
        idx_list = np.arange(len(self._query_pool))
        np.random.shuffle(idx_list)
        idxs2query_pool = idx_list[:self._num_queries]
        self._qidxs = [self._query_pool[i] for i in idxs2query_pool]

        ## Selecting positive pairs.
        # Positive examples are fixed for each query during the whole training
        # process.
        self._pidxs = [self._positive_pool[i] for i in idxs2query_pool]

        ## Selecting negative pairs.
        # If `num_negatives` is 0, create dummy nidxs.
        # Useful when only positives are used for training.
        if self._num_negatives == 0:
            self._nidxs = [[] for _ in range(len(self._qidxs))]
            return 0

        # Draw `pool_size` random images for the pool of negative images.
        neg_idx_list = np.arange(len(self.images))
        np.random.shuffle(neg_idx_list)
        neg_images_idxs = neg_idx_list[:self._pool_size]

        global_features_utils.debug_and_log(
            '>> Extracting descriptors for query images...', debug=True)

        img_list = self._img_names_to_full_path(
            [self.images[i] for i in self._qidxs])
        qvecs = global_model.extract_global_descriptors_from_list(
            net,
            images=img_list,
            image_size=self._imsize,
            print_freq=self._print_freq)

        global_features_utils.debug_and_log(
            '>> Extracting descriptors for negative pool...', debug=True)

        poolvecs = global_model.extract_global_descriptors_from_list(
            net,
            images=self._img_names_to_full_path(
                [self.images[i] for i in neg_images_idxs]),
            image_size=self._imsize,
            print_freq=self._print_freq)

        global_features_utils.debug_and_log(
            '>> Searching for hard negatives...', debug=True)

        # Compute dot product scores and ranks.
        scores = tf.linalg.matmul(poolvecs, qvecs, transpose_a=True)
        ranks = tf.argsort(scores, axis=0, direction='DESCENDING')

        sum_ndist = 0.
        n_ndist = 0.

        # Selection of negative examples.
        self._nidxs = []

        for q, qidx in enumerate(self._qidxs):
            # We are not using the query cluster, since those images are
            # potentially positive.
            qcluster = self._clusters[qidx]
            clusters = [qcluster]
            nidxs = []
            rank = 0

            while len(nidxs) < self._num_negatives:
                if rank >= tf.shape(ranks)[0]:
                    raise ValueError(
                        "Unable to create epoch tuples. Number of required "
                        "negative images is larger than the number of "
                        "clusters in the dataset.")
                potential = neg_images_idxs[ranks[rank, q]]
                # Take at most one image from the same cluster.
                if self._clusters[potential] not in clusters:
                    nidxs.append(potential)
                    clusters.append(self._clusters[potential])
                    dist = tf.norm(qvecs[:, q] - poolvecs[:, ranks[rank, q]],
                                   axis=0).numpy()
                    sum_ndist += dist
                    n_ndist += 1
                rank += 1

            self._nidxs.append(nidxs)

        global_features_utils.debug_and_log(
            '>> Average negative l2-distance: {:.2f}'.format(sum_ndist /
                                                             n_ndist))

        # Return average negative L2-distance.
        return sum_ndist / n_ndist
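The core of the hard-negative search above is a dot-product similarity between query and pool descriptors followed by a descending sort. A self-contained sketch of just that step, with random descriptors standing in for real network outputs:

# Scoring/ranking sketch (random data; cluster filtering from above omitted).
import tensorflow as tf

outputdim, num_queries, pool_size = 8, 3, 10
qvecs = tf.math.l2_normalize(tf.random.normal((outputdim, num_queries)), axis=0)
poolvecs = tf.math.l2_normalize(tf.random.normal((outputdim, pool_size)), axis=0)

# scores[i, j] = similarity between pool image i and query j.
scores = tf.linalg.matmul(poolvecs, qvecs, transpose_a=True)
ranks = tf.argsort(scores, axis=0, direction='DESCENDING')
# ranks[0, j] indexes the hardest candidate negative for query j.
print(ranks[:3].numpy())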
def train_val_one_epoch(loader,
                        model,
                        criterion,
                        optimizer,
                        epoch,
                        train=True,
                        batch_size=5,
                        query_size=2000,
                        neg_num=5,
                        update_every=1,
                        debug=False):
    """Executes either training or validation step based on `train` value.

  Args:
    loader: Training/validation iterable dataset.
    model: Network to train/validate.
    criterion: Loss function.
    optimizer: Network optimizer.
    epoch: Integer, epoch number.
    train: Bool, specifies training or validation phase.
    batch_size: Integer, number of (q,p,n1,...,nN) tuples in a mini-batch.
    query_size: Integer, number of queries randomly drawn per one training
      epoch.
    neg_num: Integer, number of negatives per tuple.
    update_every: Integer, update model weights every N batches, used to
      handle relatively large batches; the effective batch size becomes
      update_every x batch_size.
    debug: Bool, whether debug mode is used.

  Returns:
    average_epoch_loss: Average epoch loss.
  """
    batch_time = global_features_utils.AverageMeter()
    data_time = global_features_utils.AverageMeter()
    losses = global_features_utils.AverageMeter()

    # Retrieve all trainable variables we defined in the graph.
    tvs = model.trainable_variables
    accum_grads = [tf.zeros_like(tv.read_value()) for tv in tvs]

    end = time.time()
    batch_num = 0
    print_frequency = 10
    all_batch_num = query_size // batch_size
    state = 'Train' if train else 'Val'
    global_features_utils.debug_and_log('>> {} step:'.format(state))

    # Iterate over batches in the dataset; stop when all batches have been
    # processed.
    while True:
        data_time.update(time.time() - end)

        if train:
            try:
                # Train on one batch.
                # Each image in the batch is loaded into memory consecutively.
                for _ in range(batch_size):
                    # Because the images are not necessarily of the same size, we can't
                    # set the batch size with .batch().
                    batch = loader.get_next()
                    input_tuple = batch[0:-1]
                    target_tuple = batch[-1]

                    loss_value, grads = _compute_loss_and_gradient(
                        criterion, model, input_tuple, target_tuple, neg_num)
                    losses.update(loss_value)
                    # Accumulate gradients element-wise per variable (a plain
                    # `+=` on the Python list would extend it instead).
                    accum_grads = [
                        accum_grad + grad
                        for accum_grad, grad in zip(accum_grads, grads)
                    ]

                # Perform weight update if required.
                if (batch_num + 1) % update_every == 0 or (batch_num +
                                                           1) == all_batch_num:
                    # Do one step for multiple batches. Accumulated gradients are
                    # used.
                    optimizer.apply_gradients(
                        zip(accum_grads, model.trainable_variables))
                    accum_grads = [
                        tf.zeros_like(tv.read_value()) for tv in tvs
                    ]
            # We break when we run out of range, i.e., we exhausted all dataset
            # images.
            except tf.errors.OutOfRangeError:
                break

        else:
            # Validate one batch.
            # We load full batch into memory.
            input = []
            target = []
            try:
                for _ in range(batch_size):
                    # Because the images are not necessarily of the same size, we can't
                    # set the batch size with .batch().
                    batch = loader.get_next()
                    input.append(batch[0:-1])
                    target.append(batch[-1])
            # We break when we run out of range, i.e., we exhausted all dataset
            # images.
            except tf.errors.OutOfRangeError:
                break

            descriptors = tf.zeros(shape=(0, model.meta['outputdim']),
                                   dtype=tf.float32)

            for input_tuple in input:
                for img in input_tuple:
                    # Compute the global descriptor vector.
                    model_out = model(tf.expand_dims(img, axis=0),
                                      training=False)
                    descriptors = tf.concat([descriptors, model_out], 0)

            # No need to reduce memory consumption (no backward pass):
            # Compute loss for the full batch.
            target = tf.concat(target, axis=0)
            queries = tf.boolean_mask(descriptors, tf.equal(target, -1))
            positives = tf.boolean_mask(descriptors, tf.equal(target, 1))
            negatives = tf.boolean_mask(descriptors, tf.equal(target, 0))
            negatives = tf.reshape(
                negatives,
                [tf.shape(queries)[0], neg_num, model.meta['outputdim']])
            loss = criterion(queries, positives, negatives)

            # Record loss.
            losses.update(loss / batch_size, batch_size)

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        # Record immediate loss and elapsed time.
        if debug and ((batch_num + 1) % print_frequency == 0 or batch_num == 0
                      or (batch_num + 1) == all_batch_num):
            global_features_utils.debug_and_log(
                '>> {0}: [{1} epoch][{2}/{3} batch]\t '
                'Time val: {batch_time.val:.3f} '
                '(Batch Time avg: {batch_time.avg:.3f})\t '
                'Data {data_time.val:.3f} (Time avg: {data_time.avg:.3f})\t '
                'Immediate loss value: {loss.val:.4f} '
                '(Loss avg: {loss.avg:.4f})'.format(state,
                                                    epoch,
                                                    batch_num + 1,
                                                    all_batch_num,
                                                    batch_time=batch_time,
                                                    data_time=data_time,
                                                    loss=losses),
                debug=True,
                log=False)
        batch_num += 1

    return losses.avg
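`_compute_loss_and_gradient` is referenced above but not part of this excerpt. The following is a hypothetical sketch of such a helper with `tf.GradientTape`, assuming the (q, p, n1, ..., nN) tuple layout described in the docstring; it is not the repository's actual implementation.

# Hypothetical helper sketch (assumptions: tuple layout (q, p, n1, ..., nN);
# `target_tuple` is unused here because the layout already fixes the roles).
def _compute_loss_and_gradient_sketch(criterion, model, input_tuple,
                                      target_tuple, neg_num):
  """Computes the tuple loss and gradients w.r.t. the model variables."""
  with tf.GradientTape() as tape:
    descriptors = tf.concat(
        [model(tf.expand_dims(img, axis=0), training=True)
         for img in input_tuple], axis=0)
    queries = descriptors[0:1]                  # Query descriptor.
    positives = descriptors[1:2]                # Positive descriptor.
    negatives = tf.expand_dims(descriptors[2:2 + neg_num], axis=0)
    loss = criterion(queries, positives, negatives)
  return loss, tape.gradient(loss, model.trainable_variables)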
def test_retrieval(datasets,
                   net,
                   epoch,
                   writer=None,
                   model_directory=None,
                   precompute_whitening=None,
                   data_root='data',
                   multiscale=[1.],
                   test_image_size=1024):
    """Testing step.

  Evaluates the network on the provided test datasets by computing single-scale
  mAP for easy/medium/hard cases. If `writer` is specified, saves the mAP
  values in a tensorboard supported format.

  Args:
    datasets: List of dataset names for model testing (from
      `_TEST_DATASET_NAMES`).
    net: Network to evaluate.
    epoch: Integer, epoch number.
    writer: Tensorboard writer.
    model_directory: String, path to the model directory.
    precompute_whitening: Dataset used to learn whitening. If no
      precomputation required, then `None`. Only 'retrieval-SfM-30k' and
      'retrieval-SfM-120k' datasets are supported for whitening pre-computation.
    data_root: Absolute path to the data folder.
    multiscale: List of scales for multiscale testing.
    test_image_size: Integer, maximum size of the test images.
  """
    global_features_utils.debug_and_log(">> Testing step:")
    global_features_utils.debug_and_log(
        '>> Evaluating network on test datasets...')

    # Precompute whitening.
    if precompute_whitening is not None:

        # If whitening already precomputed, load it and skip the computations.
        filename = os.path.join(
            model_directory, 'learned_whitening_mP_{}_epoch.pkl'.format(epoch))
        filename_layer = os.path.join(
            model_directory,
            'learned_whitening_layer_config_{}_epoch.pkl'.format(epoch))

        if tf.io.gfile.exists(filename):
            global_features_utils.debug_and_log(
                '>> {}: Whitening for this epoch is already precomputed. '
                'Loading...'.format(precompute_whitening))
            with tf.io.gfile.GFile(filename, 'rb') as learned_whitening_file:
                learned_whitening = pickle.load(learned_whitening_file)

        else:
            start = time.time()
            global_features_utils.debug_and_log(
                '>> {}: Learning whitening...'.format(precompute_whitening))

            # Loading db.
            db_root = os.path.join(data_root, 'train', precompute_whitening)
            ims_root = os.path.join(db_root, 'ims')
            db_filename = os.path.join(
                db_root, '{}-whiten.pkl'.format(precompute_whitening))
            with tf.io.gfile.GFile(db_filename, 'rb') as f:
                db = pickle.load(f)
            images = [
                sfm120k.id2filename(db['cids'][i], ims_root)
                for i in range(len(db['cids']))
            ]

            # Extract whitening vectors.
            global_features_utils.debug_and_log(
                '>> {}: Extracting...'.format(precompute_whitening))
            wvecs = global_model.extract_global_descriptors_from_list(
                net, images, test_image_size)

            # Learning whitening.
            global_features_utils.debug_and_log(
                '>> {}: Learning...'.format(precompute_whitening))
            wvecs = wvecs.numpy()
            mean_vector, projection_matrix = whiten.whitenlearn(
                wvecs, db['qidxs'], db['pidxs'])
            learned_whitening = {'m': mean_vector, 'P': projection_matrix}

            global_features_utils.debug_and_log(
                '>> {}: Elapsed time: {}'.format(
                    precompute_whitening,
                    global_features_utils.htime(time.time() - start)))
            # Save learned_whitening parameters for a later use.
            with tf.io.gfile.GFile(filename, 'wb') as learned_whitening_file:
                pickle.dump(learned_whitening, learned_whitening_file)

            # Saving whitening as a layer.
            bias = -np.dot(mean_vector.T, projection_matrix.T)
            whitening_layer = tf.keras.layers.Dense(
                net.meta['outputdim'],
                activation=None,
                use_bias=True,
                kernel_initializer=tf.keras.initializers.Constant(
                    projection_matrix.T),
                bias_initializer=tf.keras.initializers.Constant(bias))
            with tf.io.gfile.GFile(filename_layer,
                                   'wb') as learned_whitening_file:
                pickle.dump(whitening_layer.get_config(),
                            learned_whitening_file)
    else:
        learned_whitening = None

    # Evaluate on test datasets.
    for dataset in datasets:
        start = time.time()

        # Prepare config structure for the test dataset.
        cfg = test_dataset.CreateConfigForTestDataset(dataset,
                                                      os.path.join(data_root))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        bounding_boxes = [
            tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])
        ]

        # Extract database and query vectors.
        global_features_utils.debug_and_log(
            '>> {}: Extracting database images...'.format(dataset))
        vecs = global_model.extract_global_descriptors_from_list(
            net, images, test_image_size, scales=multiscale)
        global_features_utils.debug_and_log(
            '>> {}: Extracting query images...'.format(dataset))
        qvecs = global_model.extract_global_descriptors_from_list(
            net, qimages, test_image_size, bounding_boxes, scales=multiscale)

        global_features_utils.debug_and_log(
            '>> {}: Evaluating...'.format(dataset))

        # Convert the obtained descriptors to numpy.
        vecs = vecs.numpy()
        qvecs = qvecs.numpy()

        # Search, rank and print test set metrics.
        _calculate_metrics_and_export_to_tensorboard(vecs,
                                                     qvecs,
                                                     dataset,
                                                     cfg,
                                                     writer,
                                                     epoch,
                                                     whiten=False)

        if learned_whitening is not None:
            # Whiten the vectors.
            mean_vector = learned_whitening['m']
            projection_matrix = learned_whitening['P']
            vecs_lw = whiten.whitenapply(vecs, mean_vector, projection_matrix)
            qvecs_lw = whiten.whitenapply(qvecs, mean_vector,
                                          projection_matrix)

            # Search, rank, and print.
            _calculate_metrics_and_export_to_tensorboard(vecs_lw,
                                                         qvecs_lw,
                                                         dataset,
                                                         cfg,
                                                         writer,
                                                         epoch,
                                                         whiten=True)

        global_features_utils.debug_and_log('>> {}: Elapsed time: {}'.format(
            dataset, global_features_utils.htime(time.time() - start)))
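`whiten.whitenlearn` and `whiten.whitenapply` are external to this excerpt. A common way of applying a learned mean `m` and projection matrix `P` to column-wise descriptors, given only as a hedged sketch and not necessarily the exact `whitenapply` implementation:

# Whitening application sketch (assumption: descriptors are column-wise,
# shape (outputdim, num_images), as produced by the extraction function).
import numpy as np

def whiten_apply_sketch(descriptors, mean_vector, projection_matrix):
  """Centers, projects and re-L2-normalizes descriptors column-wise."""
  centered = descriptors - mean_vector.reshape(-1, 1)
  projected = np.dot(projection_matrix, centered)
  norms = np.linalg.norm(projected, axis=0, keepdims=True) + 1e-6
  return projected / norms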
Example #7
def main(argv):
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  # Manually check if there are unknown test datasets and if the dataset
  # ground truth files are downloaded.
  for dataset in FLAGS.test_datasets:
    if dataset not in _TEST_DATASET_NAMES:
      raise ValueError('Unsupported or unknown test dataset: {}.'.format(
              dataset))

    test_data_config = os.path.join(FLAGS.data_root,
                                    'gnd_{}.pkl'.format(dataset))
    if not tf.io.gfile.exists(test_data_config):
      raise ValueError(
              '{} ground truth file at {} not found. Please download it '
              'according to '
              'the DELG instructions.'.format(dataset, FLAGS.data_root))

  # Check if train dataset is downloaded and download it if not found.
  dataset_download.download_train(FLAGS.data_root)

  # Creating model export directory if it does not exist.
  model_directory = global_features_utils.create_model_directory(
          FLAGS.training_dataset, FLAGS.arch, FLAGS.pool, FLAGS.whitening,
          FLAGS.pretrained, FLAGS.loss, FLAGS.loss_margin, FLAGS.optimizer,
          FLAGS.lr, FLAGS.weight_decay, FLAGS.neg_num, FLAGS.query_size,
          FLAGS.pool_size, FLAGS.batch_size, FLAGS.update_every,
          FLAGS.image_size, FLAGS.directory)

  # Setting up logging directory, same as where the model is stored.
  logging.get_absl_handler().use_absl_log_file('absl_logging', model_directory)

  # Set cuda visible device.
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
  global_features_utils.debug_and_log('>> Num GPUs Available: {}'.format(
          len(tf.config.experimental.list_physical_devices('GPU'))),
          FLAGS.debug)

  # Set random seeds.
  tf.random.set_seed(0)
  np.random.seed(0)

  # Initialize the model.
  if FLAGS.pretrained:
    global_features_utils.debug_and_log(
            '>> Using pre-trained model \'{}\''.format(FLAGS.arch))
  else:
    global_features_utils.debug_and_log(
            '>> Using model from scratch (random weights) \'{}\'.'.format(
                    FLAGS.arch))

  model_params = {'architecture': FLAGS.arch, 'pooling': FLAGS.pool,
                  'whitening': FLAGS.whitening, 'pretrained': FLAGS.pretrained,
                  'data_root': FLAGS.data_root}
  model = global_model.GlobalFeatureNet(**model_params)

  # Freeze running mean and std in batch normalization layers.
  # We do training one image at a time to reduce the memory requirements of
  # the network; therefore, the computed statistics would not be per-batch.
  # Instead, we choose freezing - setting the parameters of all
  # batch norm layers in the network to non-trainable (i.e., using the
  # original ImageNet statistics).
  for layer in model.feature_extractor.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
      layer.trainable = False

  global_features_utils.debug_and_log('>> Network initialized.')

  global_features_utils.debug_and_log('>> Loss: {}.'.format(FLAGS.loss))
  # Define the loss function.
  if FLAGS.loss == 'contrastive':
    criterion = ranking_losses.ContrastiveLoss(margin=FLAGS.loss_margin)
  elif FLAGS.loss == 'triplet':
    criterion = ranking_losses.TripletLoss(margin=FLAGS.loss_margin)
  else:
    raise ValueError('Loss {} not available.'.format(FLAGS.loss))

  # Defining parameters for the training.
  # When pre-computing whitening, we run evaluation before the network training
  # and the `start_epoch` is set to 0. In other cases, we start from epoch 1.
  start_epoch = 1
  exp_decay = math.exp(-0.01)
  decay_steps = FLAGS.query_size / FLAGS.batch_size

  # Define learning rate decay schedule.
  lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
          initial_learning_rate=FLAGS.lr,
          decay_steps=decay_steps,
          decay_rate=exp_decay)

  # Define the optimizer.
  if FLAGS.optimizer == 'sgd':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
            tf.keras.optimizers.SGD)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler, momentum=FLAGS.momentum)
  elif FLAGS.optimizer == 'adam':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
            tf.keras.optimizers.Adam)
    optimizer = opt(weight_decay=FLAGS.weight_decay, learning_rate=lr_scheduler)
  else:
    raise ValueError('Optimizer {} not available.'.format(FLAGS.optimizer))

  # Initializing logging.
  writer = tf.summary.create_file_writer(model_directory)
  tf.summary.experimental.set_step(1)

  # Setting up the checkpoint manager.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
          checkpoint,
          model_directory,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=3)
  if FLAGS.resume:
    # Restores the checkpoint, if existing.
    global_features_utils.debug_and_log('>> Continuing from a checkpoint.')
    checkpoint.restore(manager.latest_checkpoint)

  # Launching tensorboard if required.
  if FLAGS.launch_tensorboard:
    tensorboard = tf.keras.callbacks.TensorBoard(model_directory)
    tensorboard.set_model(model=model)
    tensorboard_utils.launch_tensorboard(log_dir=model_directory)

  # Log flags used.
  global_features_utils.debug_and_log('>> Running training script with:')
  global_features_utils.debug_and_log('>> logdir = {}'.format(model_directory))

  if FLAGS.training_dataset.startswith('retrieval-SfM-120k'):
    train_dataset = sfm120k.CreateDataset(
            data_root=FLAGS.data_root,
            mode='train',
            imsize=FLAGS.image_size,
            num_negatives=FLAGS.neg_num,
            num_queries=FLAGS.query_size,
            pool_size=FLAGS.pool_size
    )
    if FLAGS.validation_type is not None:
      val_dataset = sfm120k.CreateDataset(
              data_root=FLAGS.data_root,
              mode='val',
              imsize=FLAGS.image_size,
              num_negatives=FLAGS.neg_num,
              num_queries=float('Inf'),
              pool_size=float('Inf'),
              eccv2020=(FLAGS.validation_type == 'eccv2020')
      )

  train_dataset_output_types = [tf.float32 for _ in range(2 + FLAGS.neg_num)]
  train_dataset_output_types.append(tf.int32)

  global_features_utils.debug_and_log(
          '>> Training the {} network'.format(model_directory))
  global_features_utils.debug_and_log('>> GPU ids: {}'.format(FLAGS.gpu_id))

  with writer.as_default():

    # Precompute whitening if needed.
    if FLAGS.precompute_whitening is not None:
      epoch = 0
      train_utils.test_retrieval(
              FLAGS.test_datasets, model, writer=writer,
              epoch=epoch, model_directory=model_directory,
              precompute_whitening=FLAGS.precompute_whitening,
              data_root=FLAGS.data_root,
              multiscale=FLAGS.multiscale)

    for epoch in range(start_epoch, FLAGS.epochs + 1):
      # Set manual seeds per epoch.
      np.random.seed(epoch)
      tf.random.set_seed(epoch)

      # Find hard-negatives.
      # While positive examples are fixed for each query during the whole
      # training process (queries themselves are drawn randomly every epoch),
      # hard-negatives depend on the current CNN parameters and are re-mined
      # once per epoch.
      avg_neg_distance = train_dataset.create_epoch_tuples(model)

      def _train_gen():
        return (inst for inst in train_dataset)

      train_loader = tf.data.Dataset.from_generator(
              _train_gen,
              output_types=tuple(train_dataset_output_types))

      loss = train_utils.train_val_one_epoch(
              loader=iter(train_loader), model=model,
              criterion=criterion, optimizer=optimizer, epoch=epoch,
              batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
              neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
              debug=FLAGS.debug)

      # Write a scalar summary.
      tf.summary.scalar('train_epoch_loss', loss, step=epoch)
      # Forces summary writer to send any buffered data to storage.
      writer.flush()

      # Evaluate on validation set.
      if FLAGS.validation_type is not None and (epoch % FLAGS.test_freq == 0 or
                                                epoch == 1):
        avg_neg_distance = val_dataset.create_epoch_tuples(model,
                                                           model_directory)

        def _val_gen():
          return (inst for inst in val_dataset)

        val_loader = tf.data.Dataset.from_generator(
                _val_gen, output_types=tuple(train_dataset_output_types))

        loss = train_utils.train_val_one_epoch(
                loader=iter(val_loader), model=model,
                criterion=criterion, optimizer=None,
                epoch=epoch, train=False, batch_size=FLAGS.batch_size,
                query_size=FLAGS.query_size, neg_num=FLAGS.neg_num,
                update_every=FLAGS.update_every, debug=FLAGS.debug)
        tf.summary.scalar('val_epoch_loss', loss, step=epoch)
        writer.flush()

      # Evaluate on test datasets every test_freq epochs.
      if epoch == 1 or epoch % FLAGS.test_freq == 0:
        train_utils.test_retrieval(
                FLAGS.test_datasets, model, writer=writer, epoch=epoch,
                model_directory=model_directory,
                precompute_whitening=FLAGS.precompute_whitening,
                data_root=FLAGS.data_root, multiscale=FLAGS.multiscale)

      # Saving checkpoints and model weights.
      try:
        save_path = manager.save(checkpoint_number=epoch)
        global_features_utils.debug_and_log(
                'Saved ({}) at {}'.format(epoch, save_path))

        filename = os.path.join(model_directory,
                                'checkpoint_epoch_{}.h5'.format(epoch))
        model.save_weights(filename, save_format='h5')
        global_features_utils.debug_and_log(
                'Saved weights ({}) at {}'.format(epoch, filename))
      except Exception as ex:
        global_features_utils.debug_and_log(
                'Could not save checkpoint: {}'.format(ex))
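The `FLAGS` referenced throughout `main` are absl flag definitions that are not part of this excerpt. A minimal sketch of how a few of them could be declared; the default values and test dataset names below are assumptions:

# Minimal absl flags sketch (defaults are assumptions, not the repository's
# actual values).
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('data_root', 'data', 'Absolute path to the data folder.')
flags.DEFINE_string('arch', 'ResNet101', 'Network backbone architecture.')
flags.DEFINE_integer('epochs', 100, 'Number of training epochs.')
flags.DEFINE_integer('batch_size', 5, 'Number of tuples per mini-batch.')
flags.DEFINE_boolean('debug', False, 'Whether to run in debug mode.')
flags.DEFINE_list('test_datasets', ['roxford5k', 'rparis6k'],
                  'Test datasets for retrieval evaluation.')

if __name__ == '__main__':
  app.run(main)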