def launch_tensorboard(log_dir):
  """Runs tensorboard with the given `log_dir`.

  Args:
    log_dir: String, directory to launch tensorboard in.
  """
  tensorboard = program.TensorBoard()
  tensorboard.configure(argv=[None, '--logdir', log_dir])
  url = tensorboard.launch()
  global_features_utils.debug_and_log("Launching Tensorboard: {}".format(url))
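# Illustrative usage (a minimal sketch; assumes the `program` module above is
# imported as `from tensorboard import program`, and the log directory shown
# here is hypothetical):
#
#   launch_tensorboard(log_dir='/tmp/global_features_logs')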
def extract_global_descriptors_from_list(net, images, image_size,
                                         bounding_boxes=None, scales=[1.],
                                         multi_scale_power=1., print_freq=10):
  """Extracting global descriptors from a list of images.

  Args:
    net: Model object, network for the forward pass.
    images: Absolute image paths as strings.
    image_size: Integer, defines the maximum size of longer image side.
    bounding_boxes: List of (x1, y1, x2, y2) tuples to crop the query images.
    scales: List of float scales.
    multi_scale_power: Float, multi-scale normalization power parameter.
    print_freq: Printing frequency for debugging.

  Returns:
    descriptors: Global descriptors for the input images.
  """
  # Creating dataset loader.
  data = generic_dataset.ImagesFromList(root='', image_paths=images,
                                        imsize=image_size,
                                        bounding_boxes=bounding_boxes)

  def _data_gen():
    return (inst for inst in data)

  loader = tf.data.Dataset.from_generator(_data_gen,
                                          output_types=(tf.float32))
  loader = loader.batch(1)

  # Extracting vectors.
  descriptors = tf.zeros((0, net.meta['outputdim']))
  for i, input in enumerate(loader):
    if len(scales) == 1 and scales[0] == 1:
      descriptors = tf.concat([descriptors, net(input)], 0)
    else:
      descriptors = tf.concat([
          descriptors,
          extract_multi_scale_descriptor(net, input, scales,
                                         multi_scale_power)
      ], 0)

    if (i + 1) % print_freq == 0 or (i + 1) == len(images):
      global_features_utils.debug_and_log(
          '\r>>>> {}/{} done...'.format((i + 1), len(images)),
          debug_on_the_same_line=True)

  global_features_utils.debug_and_log('', log=False)

  descriptors = tf.transpose(descriptors, perm=[1, 0])
  return descriptors
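# Illustrative usage (a minimal sketch; `model` and `image_paths` are
# hypothetical and stand for an already-constructed GlobalFeatureNet and a
# list of absolute image paths; the scale values are only an example):
#
#   descriptors = extract_global_descriptors_from_list(
#       net=model, images=image_paths, image_size=1024)
#
#   # Multi-scale variant, combining descriptors over three scales:
#   descriptors_ms = extract_global_descriptors_from_list(
#       net=model, images=image_paths, image_size=1024,
#       scales=[1., 2 ** 0.5, 2 ** -0.5], multi_scale_power=1.)
#
#   # Note the returned tensor is transposed: shape (outputdim, num_images).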
def __init__(self, architecture='ResNet101', pooling='gem', whitening=False,
             pretrained=True, data_root=''):
  """GlobalFeatureNet network initialization.

  Args:
    architecture: Network backbone.
    pooling: Pooling method, one of 'mac'/'spoc'/'gem'.
    whitening: Bool, whether to use whitening.
    pretrained: Bool, whether to initialize the network with the weights
      pretrained on ImageNet.
    data_root: String, path to the data folder where the precomputed
      whitening is/will be saved in case `whitening` is True.

  Raises:
    ValueError: If `architecture` is not supported.
  """
  if architecture not in _OUTPUT_DIM.keys():
    raise ValueError(
        "Architecture {} is not supported.".format(architecture))

  super(GlobalFeatureNet, self).__init__()

  # Get standard output dimensionality size.
  dim = _OUTPUT_DIM[architecture]

  if pretrained:
    # Initialize with network pretrained on ImageNet.
    net_in = getattr(tf.keras.applications, architecture)(include_top=False,
                                                          weights="imagenet")
  else:
    # Initialize with random weights.
    net_in = getattr(tf.keras.applications, architecture)(include_top=False,
                                                          weights=None)

  # Initialize `feature_extractor`. Take only convolutions for
  # `feature_extractor`, always end with ReLU to make last activations
  # non-negative.
  if architecture.lower().startswith('densenet'):
    tmp_model = tf.keras.Sequential()
    tmp_model.add(net_in)
    net_in = tmp_model
    net_in.add(tf.keras.layers.ReLU())

  # Initialize pooling.
  self.pool = _POOLING[pooling]()

  # Initialize whitening.
  if whitening:
    if pretrained and architecture in _WHITENING_CONFIG:
      # If precomputed whitening for the architecture exists, the
      # fully-connected layer is going to be initialized according to the
      # precomputed layer configuration.
      global_features_utils.debug_and_log(
          ">> {}: for '{}' custom computed whitening '{}' is used.".format(
              os.getcwd(), architecture,
              os.path.basename(_WHITENING_CONFIG[architecture])))
      # The layer configuration is downloaded to the `data_root` folder.
      whiten_dir = os.path.join(data_root, architecture)
      path = tf.keras.utils.get_file(fname=whiten_dir,
                                     origin=_WHITENING_CONFIG[architecture])
      # Whitening configuration is loaded.
      with tf.io.gfile.GFile(path, 'rb') as learned_whitening_file:
        whitening_config = pickle.load(learned_whitening_file)
      # Whitening layer is initialized according to the configuration.
      self.whiten = tf.keras.layers.Dense.from_config(whitening_config)
    else:
      # In case no precomputed whitening exists for the chosen architecture,
      # the fully-connected whitening layer is initialized with random
      # weights.
      self.whiten = tf.keras.layers.Dense(dim, activation=None, use_bias=True)
      global_features_utils.debug_and_log(
          ">> There is either no whitening computed for the used network "
          "architecture or pretrained is False, random weights are used.")
  else:
    self.whiten = None

  # Create meta information to be stored in the network.
  self.meta = {
      'architecture': architecture,
      'pooling': pooling,
      'whitening': whitening,
      'outputdim': dim
  }

  self.feature_extractor = net_in
  self.normalize = normalization.L2Normalization()
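# Illustrative construction (a minimal sketch; the argument values are only
# examples and `data_root` is a hypothetical local path):
#
#   net = GlobalFeatureNet(architecture='ResNet101', pooling='gem',
#                          whitening=False, pretrained=True,
#                          data_root='/tmp/global_features_data')
#   # net.meta['outputdim'] then gives the descriptor dimensionality taken
#   # from _OUTPUT_DIM for the chosen backbone.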
def create_epoch_tuples(self, net):
  """Creates epoch tuples with the hard-negative re-mining.

  Negative examples are selected from clusters different from the cluster of
  the query image, as the clusters are ideally non-overlapping. For every
  query image we choose hard-negatives, that is, non-matching images with the
  most similar descriptor. Hard-negatives depend on the current CNN
  parameters. K-nearest neighbors from all non-matching images are selected.
  Query images are selected randomly. Positive examples are fixed for the
  related query image during the whole training process.

  Args:
    net: Model, network to be used for negative re-mining.

  Raises:
    ValueError: If the pool_size is smaller than the number of negative
      images per tuple.

  Returns:
    avg_l2: Float, average negative L2-distance.
  """
  self._n = 0

  if self._pool_size < self._num_negatives:
    raise ValueError("Unable to create epoch tuples. Negative pool_size "
                     "should be larger than the number of negative images "
                     "per tuple.")

  global_features_utils.debug_and_log(
      '>> Creating tuples for an epoch of {}-{}...'.format(self._name,
                                                           self._mode), True)
  global_features_utils.debug_and_log(">> Used network: ", True)
  global_features_utils.debug_and_log(net.meta_repr(), True)

  ## Selecting queries.
  # Draw `num_queries` random queries for the tuples.
  idx_list = np.arange(len(self._query_pool))
  np.random.shuffle(idx_list)
  idxs2query_pool = idx_list[:self._num_queries]
  self._qidxs = [self._query_pool[i] for i in idxs2query_pool]

  ## Selecting positive pairs.
  # Positive examples are fixed for each query during the whole training
  # process.
  self._pidxs = [self._positive_pool[i] for i in idxs2query_pool]

  ## Selecting negative pairs.
  # If `num_negatives` = 0 create dummy nidxs.
  # Useful when only positives are used for training.
  if self._num_negatives == 0:
    self._nidxs = [[] for _ in range(len(self._qidxs))]
    return 0

  # Draw `pool_size` random images for the pool of negative images.
  neg_idx_list = np.arange(len(self.images))
  np.random.shuffle(neg_idx_list)
  neg_images_idxs = neg_idx_list[:self._pool_size]

  global_features_utils.debug_and_log(
      '>> Extracting descriptors for query images...', debug=True)

  img_list = self._img_names_to_full_path(
      [self.images[i] for i in self._qidxs])
  qvecs = global_model.extract_global_descriptors_from_list(
      net,
      images=img_list,
      image_size=self._imsize,
      print_freq=self._print_freq)

  global_features_utils.debug_and_log(
      '>> Extracting descriptors for negative pool...', debug=True)

  poolvecs = global_model.extract_global_descriptors_from_list(
      net,
      images=self._img_names_to_full_path(
          [self.images[i] for i in neg_images_idxs]),
      image_size=self._imsize,
      print_freq=self._print_freq)

  global_features_utils.debug_and_log('>> Searching for hard negatives...',
                                      debug=True)

  # Compute dot product scores and ranks.
  scores = tf.linalg.matmul(poolvecs, qvecs, transpose_a=True)
  ranks = tf.argsort(scores, axis=0, direction='DESCENDING')

  sum_ndist = 0.
  n_ndist = 0.

  # Selection of negative examples.
  self._nidxs = []

  for q, qidx in enumerate(self._qidxs):
    # We are not using the query cluster, those images are potentially
    # positive.
    qcluster = self._clusters[qidx]
    clusters = [qcluster]
    nidxs = []
    rank = 0

    while len(nidxs) < self._num_negatives:
      if rank >= tf.shape(ranks)[0]:
        raise ValueError("Unable to create epoch tuples. Number of required "
                         "negative images is larger than the number of "
                         "clusters in the dataset.")
      potential = neg_images_idxs[ranks[rank, q]]
      # Take at most one image from the same cluster.
      if not self._clusters[potential] in clusters:
        nidxs.append(potential)
        clusters.append(self._clusters[potential])
        dist = tf.norm(qvecs[:, q] - poolvecs[:, ranks[rank, q]],
                       axis=0).numpy()
        sum_ndist += dist
        n_ndist += 1
      rank += 1

    self._nidxs.append(nidxs)

  global_features_utils.debug_and_log(
      '>> Average negative l2-distance: {:.2f}'.format(sum_ndist / n_ndist))

  # Return average negative L2-distance.
  return sum_ndist / n_ndist
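# Illustrative per-epoch usage (a minimal sketch; `train_dataset` and `model`
# are hypothetical instances of this dataset class and of GlobalFeatureNet):
#
#   avg_neg_l2 = train_dataset.create_epoch_tuples(model)
#   # After this call the dataset yields (q, p, n1, ..., nN, target) tuples
#   # built from the freshly re-mined hard negatives.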
def train_val_one_epoch(loader, model, criterion, optimizer, epoch, train=True,
                        batch_size=5, query_size=2000, neg_num=5,
                        update_every=1, debug=False):
  """Executes either a training or a validation step based on `train` value.

  Args:
    loader: Training/validation iterable dataset.
    model: Network to train/validate.
    criterion: Loss function.
    optimizer: Network optimizer.
    epoch: Integer, epoch number.
    train: Bool, specifies training or validation phase.
    batch_size: Integer, number of (q, p, n1, ..., nN) tuples in a mini-batch.
    query_size: Integer, number of queries randomly drawn per one training
      epoch.
    neg_num: Integer, number of negatives per tuple.
    update_every: Integer, update model weights every N batches, used to
      handle relatively large batches; batch_size effectively becomes
      update_every x batch_size.
    debug: Bool, whether debug mode is used.

  Returns:
    average_epoch_loss: Average epoch loss.
  """
  batch_time = global_features_utils.AverageMeter()
  data_time = global_features_utils.AverageMeter()
  losses = global_features_utils.AverageMeter()

  # Retrieve all trainable variables we defined in the graph.
  tvs = model.trainable_variables
  accum_grads = [tf.zeros_like(tv.read_value()) for tv in tvs]

  end = time.time()
  batch_num = 0
  print_frequency = 10
  all_batch_num = query_size // batch_size
  state = 'Train' if train else 'Val'
  global_features_utils.debug_and_log('>> {} step:'.format(state))

  # For every batch in the dataset; stops when all batches in the dataset
  # have been processed.
  while True:
    data_time.update(time.time() - end)

    if train:
      try:
        # Train on one batch.
        # Each image in the batch is loaded into memory consecutively.
        for _ in range(batch_size):
          # Because the images are not necessarily of the same size, we can't
          # set the batch size with .batch().
          batch = loader.get_next()
          input_tuple = batch[0:-1]
          target_tuple = batch[-1]

          loss_value, grads = _compute_loss_and_gradient(
              criterion, model, input_tuple, target_tuple, neg_num)
          losses.update(loss_value)
          # Accumulate gradients element-wise, one entry per trainable
          # variable.
          accum_grads = [
              accum_grad + grad
              for accum_grad, grad in zip(accum_grads, grads)
          ]

        # Perform weight update if required.
        if (batch_num + 1) % update_every == 0 or (
            batch_num + 1) == all_batch_num:
          # Do one step for multiple batches. Accumulated gradients are used.
          optimizer.apply_gradients(
              zip(accum_grads, model.trainable_variables))
          accum_grads = [tf.zeros_like(tv.read_value()) for tv in tvs]
      # We break when we run out of range, i.e., we exhausted all dataset
      # images.
      except tf.errors.OutOfRangeError:
        break
    else:
      # Validate one batch.
      # We load the full batch into memory.
      input = []
      target = []
      try:
        for _ in range(batch_size):
          # Because the images are not necessarily of the same size, we can't
          # set the batch size with .batch().
          batch = loader.get_next()
          input.append(batch[0:-1])
          target.append(batch[-1])
      # We break when we run out of range, i.e., we exhausted all dataset
      # images.
      except tf.errors.OutOfRangeError:
        break

      descriptors = tf.zeros(shape=(0, model.meta['outputdim']),
                             dtype=tf.float32)

      for input_tuple in input:
        for img in input_tuple:
          # Compute the global descriptor vector.
          model_out = model(tf.expand_dims(img, axis=0), training=False)
          descriptors = tf.concat([descriptors, model_out], 0)

      # Stack the per-tuple labels into one tensor so that the boolean masks
      # below select the query/positive/negative descriptor rows.
      target = tf.concat(target, axis=0)

      # No need to reduce memory consumption (no backward pass):
      # Compute loss for the full batch.
      queries = descriptors[target == -1]
      positives = descriptors[target == 1]
      negatives = descriptors[target == 0]

      negatives = tf.reshape(
          negatives,
          [tf.shape(queries)[0], neg_num, model.meta['outputdim']])
      loss = criterion(queries, positives, negatives)

      # Record loss.
      losses.update(loss / batch_size, batch_size)

    # Measure elapsed time.
    batch_time.update(time.time() - end)
    end = time.time()

    # Record immediate loss and elapsed time.
    if debug and ((batch_num + 1) % print_frequency == 0 or batch_num == 0 or
                  (batch_num + 1) == all_batch_num):
      global_features_utils.debug_and_log(
          '>> {0}: [{1} epoch][{2}/{3} batch]\t Time val: '
          '{batch_time.val:.3f} (Batch Time avg: {batch_time.avg:.3f})\t '
          'Data {data_time.val:.3f} (Time avg: {data_time.avg:.3f})\t '
          'Immediate loss value: {loss.val:.4f} '
          '(Loss avg: {loss.avg:.4f})'.format(state, epoch, batch_num + 1,
                                              all_batch_num,
                                              batch_time=batch_time,
                                              data_time=data_time,
                                              loss=losses),
          debug=True, log=False)
    batch_num += 1

  return losses.avg
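# Illustrative call for one training epoch (a minimal sketch; `train_loader`,
# `model`, `criterion` and `optimizer` are hypothetical objects created as in
# the training script further below):
#
#   epoch_loss = train_val_one_epoch(
#       loader=iter(train_loader), model=model, criterion=criterion,
#       optimizer=optimizer, epoch=1, train=True, batch_size=5,
#       query_size=2000, neg_num=5, update_every=1, debug=False)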
def test_retrieval(datasets, net, epoch, writer=None, model_directory=None,
                   precompute_whitening=None, data_root='data',
                   multiscale=[1.], test_image_size=1024):
  """Testing step.

  Evaluates the network on the provided test datasets by computing
  single-scale mAP for easy/medium/hard cases. If `writer` is specified,
  saves the mAP values in a tensorboard-supported format.

  Args:
    datasets: List of dataset names for model testing (from
      `_TEST_DATASET_NAMES`).
    net: Network to evaluate.
    epoch: Integer, epoch number.
    writer: Tensorboard writer.
    model_directory: String, path to the model directory.
    precompute_whitening: Dataset used to learn whitening. If no
      precomputation is required, then `None`. Only 'retrieval-SfM-30k' and
      'retrieval-SfM-120k' datasets are supported for whitening
      pre-computation.
    data_root: Absolute path to the data folder.
    multiscale: List of scales for multiscale testing.
    test_image_size: Integer, maximum size of the test images.
  """
  global_features_utils.debug_and_log(">> Testing step:")
  global_features_utils.debug_and_log(
      '>> Evaluating network on test datasets...')

  # Precompute whitening.
  if precompute_whitening is not None:

    # If whitening is already precomputed, load it and skip the computations.
    filename = os.path.join(
        model_directory, 'learned_whitening_mP_{}_epoch.pkl'.format(epoch))
    filename_layer = os.path.join(
        model_directory,
        'learned_whitening_layer_config_{}_epoch.pkl'.format(epoch))

    if tf.io.gfile.exists(filename):
      global_features_utils.debug_and_log(
          '>> {}: Whitening for this epoch is already precomputed. '
          'Loading...'.format(precompute_whitening))
      with tf.io.gfile.GFile(filename, 'rb') as learned_whitening_file:
        learned_whitening = pickle.load(learned_whitening_file)

    else:
      start = time.time()
      global_features_utils.debug_and_log(
          '>> {}: Learning whitening...'.format(precompute_whitening))

      # Loading db.
      db_root = os.path.join(data_root, 'train', precompute_whitening)
      ims_root = os.path.join(db_root, 'ims')
      db_filename = os.path.join(
          db_root, '{}-whiten.pkl'.format(precompute_whitening))
      with tf.io.gfile.GFile(db_filename, 'rb') as f:
        db = pickle.load(f)
      images = [sfm120k.id2filename(db['cids'][i], ims_root)
                for i in range(len(db['cids']))]

      # Extract whitening vectors.
      global_features_utils.debug_and_log(
          '>> {}: Extracting...'.format(precompute_whitening))
      wvecs = global_model.extract_global_descriptors_from_list(
          net, images, test_image_size)

      # Learning whitening.
      global_features_utils.debug_and_log(
          '>> {}: Learning...'.format(precompute_whitening))
      wvecs = wvecs.numpy()
      mean_vector, projection_matrix = whiten.whitenlearn(
          wvecs, db['qidxs'], db['pidxs'])
      learned_whitening = {'m': mean_vector, 'P': projection_matrix}

      global_features_utils.debug_and_log(
          '>> {}: Elapsed time: {}'.format(
              precompute_whitening,
              global_features_utils.htime(time.time() - start)))

      # Save learned_whitening parameters for later use.
      with tf.io.gfile.GFile(filename, 'wb') as learned_whitening_file:
        pickle.dump(learned_whitening, learned_whitening_file)

      # Saving whitening as a layer.
      bias = -np.dot(mean_vector.T, projection_matrix.T)
      whitening_layer = tf.keras.layers.Dense(
          net.meta['outputdim'],
          activation=None,
          use_bias=True,
          kernel_initializer=tf.keras.initializers.Constant(
              projection_matrix.T),
          bias_initializer=tf.keras.initializers.Constant(bias))
      with tf.io.gfile.GFile(filename_layer, 'wb') as learned_whitening_file:
        pickle.dump(whitening_layer.get_config(), learned_whitening_file)
  else:
    learned_whitening = None

  # Evaluate on test datasets.
  for dataset in datasets:
    start = time.time()

    # Prepare config structure for the test dataset.
    cfg = test_dataset.CreateConfigForTestDataset(dataset,
                                                  os.path.join(data_root))
    images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
    qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
    bounding_boxes = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]

    # Extract database and query vectors.
    global_features_utils.debug_and_log(
        '>> {}: Extracting database images...'.format(dataset))
    vecs = global_model.extract_global_descriptors_from_list(
        net, images, test_image_size, scales=multiscale)

    global_features_utils.debug_and_log(
        '>> {}: Extracting query images...'.format(dataset))
    qvecs = global_model.extract_global_descriptors_from_list(
        net, qimages, test_image_size, bounding_boxes, scales=multiscale)

    global_features_utils.debug_and_log('>> {}: Evaluating...'.format(dataset))

    # Convert the obtained descriptors to numpy.
    vecs = vecs.numpy()
    qvecs = qvecs.numpy()

    # Search, rank and print test set metrics.
    _calculate_metrics_and_export_to_tensorboard(vecs, qvecs, dataset, cfg,
                                                 writer, epoch, whiten=False)

    if learned_whitening is not None:
      # Whiten the vectors.
      mean_vector = learned_whitening['m']
      projection_matrix = learned_whitening['P']
      vecs_lw = whiten.whitenapply(vecs, mean_vector, projection_matrix)
      qvecs_lw = whiten.whitenapply(qvecs, mean_vector, projection_matrix)

      # Search, rank, and print.
      _calculate_metrics_and_export_to_tensorboard(vecs_lw, qvecs_lw, dataset,
                                                   cfg, writer, epoch,
                                                   whiten=True)

    global_features_utils.debug_and_log(
        '>> {}: Elapsed time: {}'.format(
            dataset, global_features_utils.htime(time.time() - start)))
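# Illustrative evaluation call (a minimal sketch; `model` and `summary_writer`
# are hypothetical, and the dataset names must come from _TEST_DATASET_NAMES,
# e.g. the revisited Oxford/Paris sets referenced by the gnd_{}.pkl files):
#
#   test_retrieval(['roxford5k', 'rparis6k'], model, epoch=1,
#                  writer=summary_writer, model_directory='/tmp/model_dir',
#                  precompute_whitening='retrieval-SfM-120k',
#                  data_root='data', multiscale=[1.])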
def main(argv):
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  # Manually check if there are unknown test datasets and if the dataset
  # ground truth files are downloaded.
  for dataset in FLAGS.test_datasets:
    if dataset not in _TEST_DATASET_NAMES:
      raise ValueError('Unsupported or unknown test dataset: {}.'.format(
          dataset))

    test_data_config = os.path.join(FLAGS.data_root,
                                    'gnd_{}.pkl'.format(dataset))
    if not tf.io.gfile.exists(test_data_config):
      raise ValueError(
          '{} ground truth file at {} not found. Please download it '
          'according to the DELG instructions.'.format(dataset,
                                                       FLAGS.data_root))

  # Check if train dataset is downloaded and download it if not found.
  dataset_download.download_train(FLAGS.data_root)

  # Creating model export directory if it does not exist.
  model_directory = global_features_utils.create_model_directory(
      FLAGS.training_dataset, FLAGS.arch, FLAGS.pool, FLAGS.whitening,
      FLAGS.pretrained, FLAGS.loss, FLAGS.loss_margin, FLAGS.optimizer,
      FLAGS.lr, FLAGS.weight_decay, FLAGS.neg_num, FLAGS.query_size,
      FLAGS.pool_size, FLAGS.batch_size, FLAGS.update_every,
      FLAGS.image_size, FLAGS.directory)

  # Setting up logging directory, same as where the model is stored.
  logging.get_absl_handler().use_absl_log_file('absl_logging', model_directory)

  # Set cuda visible device.
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
  global_features_utils.debug_and_log(
      '>> Num GPUs Available: {}'.format(
          len(tf.config.experimental.list_physical_devices('GPU'))),
      FLAGS.debug)

  # Set random seeds.
  tf.random.set_seed(0)
  np.random.seed(0)

  # Initialize the model.
  if FLAGS.pretrained:
    global_features_utils.debug_and_log(
        '>> Using pre-trained model \'{}\''.format(FLAGS.arch))
  else:
    global_features_utils.debug_and_log(
        '>> Using model from scratch (random weights) \'{}\'.'.format(
            FLAGS.arch))

  model_params = {'architecture': FLAGS.arch, 'pooling': FLAGS.pool,
                  'whitening': FLAGS.whitening, 'pretrained': FLAGS.pretrained,
                  'data_root': FLAGS.data_root}
  model = global_model.GlobalFeatureNet(**model_params)

  # Freeze running mean and std in batch normalization layers.
  # We do training one image at a time to reduce the memory requirements of
  # the network; therefore, the computed statistics would not be per batch.
  # Instead, we choose freezing: setting the parameters of all batch norm
  # layers in the network to non-trainable (i.e., using the original
  # imagenet statistics).
  for layer in model.feature_extractor.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
      layer.trainable = False

  global_features_utils.debug_and_log('>> Network initialized.')
  global_features_utils.debug_and_log('>> Loss: {}.'.format(FLAGS.loss))

  # Define the loss function.
  if FLAGS.loss == 'contrastive':
    criterion = ranking_losses.ContrastiveLoss(margin=FLAGS.loss_margin)
  elif FLAGS.loss == 'triplet':
    criterion = ranking_losses.TripletLoss(margin=FLAGS.loss_margin)
  else:
    raise ValueError('Loss {} not available.'.format(FLAGS.loss))

  # Defining parameters for the training.
  # When pre-computing whitening, we run evaluation before the network
  # training and the `start_epoch` is set to 0. In other cases, we start
  # from epoch 1.
  start_epoch = 1
  exp_decay = math.exp(-0.01)
  decay_steps = FLAGS.query_size / FLAGS.batch_size

  # Define learning rate decay schedule.
  lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=FLAGS.lr,
      decay_steps=decay_steps,
      decay_rate=exp_decay)

  # Define the optimizer.
  if FLAGS.optimizer == 'sgd':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.SGD)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler, momentum=FLAGS.momentum)
  elif FLAGS.optimizer == 'adam':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.Adam)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler)
  else:
    raise ValueError('Optimizer {} not available.'.format(FLAGS.optimizer))

  # Initializing logging.
  writer = tf.summary.create_file_writer(model_directory)
  tf.summary.experimental.set_step(1)

  # Setting up the checkpoint manager.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint, model_directory, max_to_keep=10,
      keep_checkpoint_every_n_hours=3)
  if FLAGS.resume:
    # Restores the checkpoint, if existing.
    global_features_utils.debug_and_log('>> Continuing from a checkpoint.')
    checkpoint.restore(manager.latest_checkpoint)

  # Launching tensorboard if required.
  if FLAGS.launch_tensorboard:
    tensorboard = tf.keras.callbacks.TensorBoard(model_directory)
    tensorboard.set_model(model=model)
    tensorboard_utils.launch_tensorboard(log_dir=model_directory)

  # Log flags used.
  global_features_utils.debug_and_log('>> Running training script with:')
  global_features_utils.debug_and_log(
      '>> logdir = {}'.format(model_directory))

  if FLAGS.training_dataset.startswith('retrieval-SfM-120k'):
    train_dataset = sfm120k.CreateDataset(
        data_root=FLAGS.data_root,
        mode='train',
        imsize=FLAGS.image_size,
        num_negatives=FLAGS.neg_num,
        num_queries=FLAGS.query_size,
        pool_size=FLAGS.pool_size
    )
    if FLAGS.validation_type is not None:
      val_dataset = sfm120k.CreateDataset(
          data_root=FLAGS.data_root,
          mode='val',
          imsize=FLAGS.image_size,
          num_negatives=FLAGS.neg_num,
          num_queries=float('Inf'),
          pool_size=float('Inf'),
          eccv2020=True if FLAGS.validation_type == 'eccv2020' else False
      )

  train_dataset_output_types = [tf.float32 for i in range(2 + FLAGS.neg_num)]
  train_dataset_output_types.append(tf.int32)

  global_features_utils.debug_and_log(
      '>> Training the {} network'.format(model_directory))
  global_features_utils.debug_and_log('>> GPU ids: {}'.format(FLAGS.gpu_id))

  with writer.as_default():

    # Precompute whitening if needed.
    if FLAGS.precompute_whitening is not None:
      epoch = 0
      train_utils.test_retrieval(
          FLAGS.test_datasets, model, writer=writer, epoch=epoch,
          model_directory=model_directory,
          precompute_whitening=FLAGS.precompute_whitening,
          data_root=FLAGS.data_root,
          multiscale=FLAGS.multiscale)

    for epoch in range(start_epoch, FLAGS.epochs + 1):

      # Set manual seeds per epoch.
      np.random.seed(epoch)
      tf.random.set_seed(epoch)

      # Find hard-negatives.
      # While hard-positive examples are fixed during the whole training
      # process and are randomly chosen from every epoch, hard-negatives
      # depend on the current CNN parameters and are re-mined once per epoch.
      avg_neg_distance = train_dataset.create_epoch_tuples(model)

      def _train_gen():
        return (inst for inst in train_dataset)

      train_loader = tf.data.Dataset.from_generator(
          _train_gen, output_types=tuple(train_dataset_output_types))

      loss = train_utils.train_val_one_epoch(
          loader=iter(train_loader), model=model,
          criterion=criterion, optimizer=optimizer, epoch=epoch,
          batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
          neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
          debug=FLAGS.debug)

      # Write a scalar summary.
      tf.summary.scalar('train_epoch_loss', loss, step=epoch)
      # Forces summary writer to send any buffered data to storage.
      writer.flush()

      # Evaluate on validation set.
      if FLAGS.validation_type is not None and (
          epoch % FLAGS.test_freq == 0 or epoch == 1):
        avg_neg_distance = val_dataset.create_epoch_tuples(model,
                                                           model_directory)

        def _val_gen():
          return (inst for inst in val_dataset)

        val_loader = tf.data.Dataset.from_generator(
            _val_gen, output_types=tuple(train_dataset_output_types))

        loss = train_utils.train_val_one_epoch(
            loader=iter(val_loader), model=model, criterion=criterion,
            optimizer=None, epoch=epoch, train=False,
            batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
            neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
            debug=FLAGS.debug)

        tf.summary.scalar('val_epoch_loss', loss, step=epoch)
        writer.flush()

      # Evaluate on test datasets every test_freq epochs.
      if epoch == 1 or epoch % FLAGS.test_freq == 0:
        train_utils.test_retrieval(
            FLAGS.test_datasets, model, writer=writer, epoch=epoch,
            model_directory=model_directory,
            precompute_whitening=FLAGS.precompute_whitening,
            data_root=FLAGS.data_root,
            multiscale=FLAGS.multiscale)

      # Saving checkpoints and model weights.
      try:
        save_path = manager.save(checkpoint_number=epoch)
        global_features_utils.debug_and_log(
            'Saved ({}) at {}'.format(epoch, save_path))

        filename = os.path.join(model_directory,
                                'checkpoint_epoch_{}.h5'.format(epoch))
        model.save_weights(filename, save_format='h5')
        global_features_utils.debug_and_log(
            'Saved weights ({}) at {}'.format(epoch, filename))
      except Exception as ex:
        global_features_utils.debug_and_log(
            'Could not save checkpoint: {}'.format(ex))
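# Illustrative invocation (a minimal sketch; the script name `train.py` and
# the flag values are assumptions based on the FLAGS referenced above, not a
# prescribed configuration):
#
#   python3 train.py --data_root=/tmp/data --arch=ResNet101 --pool=gem \
#     --loss=contrastive --optimizer=adam --batch_size=5 --neg_num=5 \
#     --query_size=2000 --pool_size=20000 --launch_tensorboard --debug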