def evaluate():
  """Extract embeddings from a restored checkpoint and save them to disk.

  All settings come from FLAGS. Side effects on the global CONFIG:
    * CONFIG.DATA.FRAME_LABELS is set from FLAGS.keep_labels.
    * CONFIG.DATA.SAMPLE_ALL_STRIDE is set from FLAGS.sample_all_stride.

  Steps:
    1. Prepare the eval directory FLAGS.logdir.
    2. Build the training algorithm named by CONFIG.TRAINING_ALGO.
    3. Restore its model (and an optimizer) from the checkpoint in logdir.
    4. Run one epoch of FLAGS.dataset / FLAGS.split in 'eval' mode through
       get_embeddings_dataset.
    5. Write the result to FLAGS.save_path with np.save.
  """
  logdir = FLAGS.logdir
  setup_eval_dir(logdir)

  # Can ignore frame labels if dataset doesn't have per-frame labels.
  CONFIG.DATA.FRAME_LABELS = FLAGS.keep_labels
  # Subsample frames in case videos are long or fps is high to save memory.
  CONFIG.DATA.SAMPLE_ALL_STRIDE = FLAGS.sample_all_stride

  algo = get_algo(CONFIG.TRAINING_ALGO)
  # Only the optimizer is needed here (it is handed to restore_ckpt); the
  # learning rate and global step are discarded.
  _, optimizer, _ = get_lr_opt_global_step()
  restore_ckpt(logdir=logdir, optimizer=optimizer, **algo.model)

  if FLAGS.defun:
    # Wrap the algorithm's call / loss methods in tf.function.
    algo.call = tf.function(algo.call)
    algo.compute_loss = tf.function(algo.compute_loss)

  iterator = create_one_epoch_dataset(
      FLAGS.dataset,
      FLAGS.split,
      mode='eval',
      path_to_tfrecords=FLAGS.path_to_tfrecords)

  # A non-positive FLAGS.max_embs is mapped to None (presumably "no limit"
  # inside get_embeddings_dataset -- confirm there).
  max_embs = None if FLAGS.max_embs <= 0 else FLAGS.max_embs
  embeddings = get_embeddings_dataset(
      algo.model,
      iterator,
      frames_per_batch=FLAGS.frames_per_batch,
      keep_data=FLAGS.keep_data,
      keep_labels=FLAGS.keep_labels,
      max_embs=max_embs)

  # np.save emits binary .npy bytes, so the file is opened in binary mode.
  # The context manager guarantees the GFile is closed -- which is what
  # flushes its writable buffer -- even if np.save raises.
  # NOTE(review): if `embeddings` is a dict/object, np.save pickles it and
  # loading requires np.load(..., allow_pickle=True) -- confirm with callers.
  with gfile.Open(FLAGS.save_path, 'wb') as f:
    np.save(f, embeddings)
def evaluate():
  """Evaluate embeddings.

  Points CONFIG.LOGDIR at FLAGS.logdir, builds the algorithm named by
  CONFIG.TRAINING_ALGO and splits CONFIG.EVAL.TASKS (via get_tasks) into
  iterator-based and embedding-based tasks.

  Eval summaries are written under <logdir>/eval_logs. Train/val eval
  datasets are created only when at least one iterator-based task exists.

  Run modes:
    * FLAGS.continuous_eval set: evaluate_once runs for every checkpoint
      yielded by tf.train.checkpoints_iterator on logdir.
    * Otherwise: evaluate_once runs exactly once.
  """
  CONFIG.LOGDIR = FLAGS.logdir
  logdir = CONFIG.LOGDIR
  setup_eval_dir(logdir)

  algo = get_algo(CONFIG.TRAINING_ALGO)
  if FLAGS.defun:
    algo.call = tf.function(algo.call)
    algo.compute_loss = tf.function(algo.compute_loss)

  iterator_tasks, embedding_tasks = get_tasks(CONFIG.EVAL.TASKS)

  # Setup summary writer.
  eval_log_dir = os.path.join(logdir, 'eval_logs')
  summary_writer = tf.summary.create_file_writer(
      eval_log_dir, flush_millis=10000)

  # Dataset iterators are only required by iterator-based tasks; 'train' is
  # created before 'val'.
  iterators = {}
  if iterator_tasks:
    iterators.update(
        train_iterator=create_dataset('train', mode='eval'),
        val_iterator=create_dataset('val', mode='eval'))

  def _run_eval_pass():
    # One full evaluation with everything prepared above.
    evaluate_once(algo, iterator_tasks, embedding_tasks, iterators,
                  summary_writer)

  if not FLAGS.continuous_eval:
    _run_eval_pass()
    return

  for _ in tf.train.checkpoints_iterator(logdir, timeout=1,
                                         timeout_fn=timeout_fn):
    _run_eval_pass()