def evaluate():
  """Extract embeddings for one dataset split and save them to disk.

  Restores the model from the checkpoint found under `FLAGS.logdir`, runs it
  over a single epoch of `FLAGS.dataset` / `FLAGS.split` and writes the
  resulting embeddings to `FLAGS.save_path` with `np.save`.

  All inputs are read from `FLAGS` and `CONFIG`; nothing is returned.
  """
  logdir = FLAGS.logdir
  setup_eval_dir(logdir)

  # Can ignore frame labels if dataset doesn't have per-frame labels.
  CONFIG.DATA.FRAME_LABELS = FLAGS.keep_labels
  # Subsample frames in case videos are long or fps is high to save memory.
  CONFIG.DATA.SAMPLE_ALL_STRIDE = FLAGS.sample_all_stride

  algo = get_algo(CONFIG.TRAINING_ALGO)
  # Only the optimizer is needed here, so that the checkpoint can be restored.
  _, optimizer, _ = get_lr_opt_global_step()
  restore_ckpt(logdir=logdir, optimizer=optimizer, **algo.model)

  if FLAGS.defun:
    algo.call = tf.function(algo.call)
    algo.compute_loss = tf.function(algo.compute_loss)

  iterator = create_one_epoch_dataset(
      FLAGS.dataset, FLAGS.split, mode='eval',
      path_to_tfrecords=FLAGS.path_to_tfrecords)

  # A non-positive --max_embs is forwarded as None (presumably "no limit";
  # confirm against get_embeddings_dataset).
  max_embs = None if FLAGS.max_embs <= 0 else FLAGS.max_embs
  embeddings = get_embeddings_dataset(
      algo.model,
      iterator,
      frames_per_batch=FLAGS.frames_per_batch,
      keep_data=FLAGS.keep_data,
      keep_labels=FLAGS.keep_labels,
      max_embs=max_embs)

  # np.save emits binary data, so open in binary mode, and use a context
  # manager so the handle is flushed and closed deterministically instead of
  # whenever the file object happens to be garbage collected.
  with gfile.Open(FLAGS.save_path, 'wb') as save_file:
    np.save(save_file, embeddings)
def evaluate_once(algo, iterator_tasks, embedding_tasks, iterators,
                  summary_writer):
  """Run all downstream tasks once against the restored checkpoint.

  Args:
    algo: Algorithm object; `algo.model` is expanded into keyword arguments
      for `restore_ckpt` and is passed to `get_embeddings_dataset`.
    iterator_tasks: Mapping of task name to a task whose `evaluate` is called
      with `iterators=`. May be empty.
    embedding_tasks: Mapping of task name to a task whose `evaluate` is called
      with `embeddings_dataset=`. May be empty.
    iterators: Forwarded unchanged to every iterator task.
    summary_writer: Summary writer made the default while tasks run and
      while the aggregated metrics are written.

  Side effects:
    Sets the module-level `evaluated_last_ckpt` to True when the restored
    global step equals `CONFIG.TRAIN.MAX_ITERS`.
  """
  global evaluated_last_ckpt

  # Restore the checkpoint; the optimizer is only created so that the
  # checkpoint can be restored.
  _, optimizer, global_step = get_lr_opt_global_step()
  restore_ckpt(logdir=CONFIG.LOGDIR, optimizer=optimizer, **algo.model)

  if global_step.numpy() == CONFIG.TRAIN.MAX_ITERS:
    evaluated_last_ckpt = True

  metrics = {}

  # Tasks that consume the dataset iterators directly.
  if iterator_tasks:
    with summary_writer.as_default(), tf.summary.record_if(True):
      for name, task in iterator_tasks.items():
        metrics[name] = task.evaluate(algo, global_step, iterators=iterators)

  # A non-positive --max_embs is forwarded as None.
  max_embs = FLAGS.max_embs
  if max_embs <= 0:
    max_embs = None

  # Tasks that consume pre-computed embeddings, evaluated per dataset.
  if embedding_tasks:
    frames_per_batch = CONFIG.EVAL.FRAMES_PER_BATCH
    # (split name, key under which its embeddings are stored), in the order
    # the splits are extracted.
    split_keys = (('train', 'train_dataset'), ('val', 'val_dataset'))

    for dataset_name in CONFIG.DATASETS:
      dataset = {'name': dataset_name}
      for split, key in split_keys:
        split_iterator = create_one_epoch_dataset(
            dataset_name, split, mode='eval',
            path_to_tfrecords=CONFIG.PATH_TO_TFRECORDS)
        dataset[key] = get_embeddings_dataset(
            algo.model, split_iterator,
            frames_per_batch=frames_per_batch,
            max_embs=max_embs)

      with summary_writer.as_default(), tf.summary.record_if(True):
        for name, task in embedding_tasks.items():
          per_dataset = metrics.setdefault(name, {})
          per_dataset[dataset_name] = task.evaluate(
              algo, global_step, embeddings_dataset=dataset)

    # Add all metrics in a separate tag so that analysis is easier.
    with summary_writer.as_default(), tf.summary.record_if(True):
      for name in embedding_tasks:
        for dataset_name in CONFIG.DATASETS:
          tf.summary.scalar('metrics/%s_%s' % (dataset_name, name),
                            metrics[name][dataset_name],
                            step=global_step)
        # Mean of this task's metric over all configured datasets.
        avg_metric = sum(metrics[name].values())
        avg_metric /= len(CONFIG.DATASETS)
        tf.summary.scalar('metrics/all_%s' % name,
                          avg_metric, step=global_step)
def train():
  """Train the configured algorithm, writing summaries and checkpoints.

  Reads the log directory from `FLAGS.logdir` (mirrored into
  `CONFIG.LOGDIR`), builds the model under a `tf.distribute.MirroredStrategy`
  scope, resumes from any checkpoint returned by `restore_ckpt`, and iterates
  until the global step reaches `CONFIG.TRAIN.MAX_ITERS`.

  A checkpoint is always saved on exit, including after a KeyboardInterrupt.
  """
  CONFIG.LOGDIR = FLAGS.logdir
  logdir = CONFIG.LOGDIR
  setup_train_dir(logdir)

  # Common code for multigpu and single gpu. Set devices here if you don't
  # want to use all the GPUs on the machine. Default is to use all GPUs.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    algo = get_algo(CONFIG.TRAINING_ALGO)

    # Summaries go to <logdir>/train_logs.
    summary_writer = tf.summary.create_file_writer(
        os.path.join(logdir, 'train_logs'), flush_millis=10000)

    learning_rate, optimizer, global_step = get_lr_opt_global_step()
    ckpt_manager, _, _ = restore_ckpt(
        logdir=logdir, optimizer=optimizer, **algo.model)

    # Python-side copy of the step counter, used for loop control and logging.
    global_step_value = global_step.numpy()

    # Remember in Eager mode learning rate variable needs to be updated
    # manually. Calling lr_fn each iteration to get current learning rate.
    lr_fn = get_lr_fn(CONFIG.OPTIMIZER)

    # The dataset is batched with the global batch size, i.e. the per-replica
    # batch size times the number of replicas.
    total_batch_size = (
        CONFIG.TRAIN.BATCH_SIZE * strategy.num_replicas_in_sync)
    train_ds = create_dataset('train', mode='train',
                              batch_size=total_batch_size,
                              return_iterator=False)
    train_iterator = strategy.make_dataset_iterator(train_ds)

    def train_step(data):
      """Runs one training iteration on a batch and returns its loss."""
      return algo.train_one_iter(
          data, data['chosen_steps'], data['seq_lens'], global_step,
          optimizer)

    def dist_train(it):
      """Runs `train_step` via the strategy and sums the returned losses."""
      # This reduction only affects reporting, not the gradients.
      return strategy.reduce(
          tf.distribute.ReduceOp.SUM,
          strategy.experimental_run(train_step, it),
          axis=None)

    if FLAGS.defun:
      dist_train = tf.function(dist_train)

    stopwatch = Stopwatch()

    try:
      while global_step_value < CONFIG.TRAIN.MAX_ITERS:
        # Summaries are only recorded on every REPORT_INTERVAL-th step.
        with summary_writer.as_default(), tf.summary.record_if(
            global_step_value % CONFIG.LOGGING.REPORT_INTERVAL == 0):
          loss = dist_train(train_iterator)

          # Update learning rate based on lr_fn.
          learning_rate.assign(lr_fn(learning_rate, global_step))

          tf.summary.scalar('loss', loss, step=global_step)
          tf.summary.scalar('learning_rate', learning_rate, step=global_step)

          # Periodic checkpoint, keyed on the step value from before this
          # iteration's refresh below.
          if global_step_value % CONFIG.CHECKPOINT.SAVE_INTERVAL == 0:
            ckpt_manager.save()
            logging.info('Checkpoint saved at iter %d.', global_step_value)

          # Refresh the Python-side step counter (global_step is presumably
          # advanced inside algo.train_one_iter -- confirm).
          global_step_value = global_step.numpy()

          time_per_iter = stopwatch.elapsed()
          tf.summary.scalar(
              'timing/time_per_iter', time_per_iter, step=global_step)

          progress = 'Iter[{}/{}], {:.1f}s/iter, Loss: {:.3f}'.format(
              global_step_value, CONFIG.TRAIN.MAX_ITERS, time_per_iter,
              loss.numpy())
          logging.info(progress)

          # Reset stopwatch after iter is complete.
          stopwatch.reset()

    except KeyboardInterrupt:
      logging.info('Caught keyboard interrupt. Saving model before quitting.')

    finally:
      # Save the final checkpoint.
      ckpt_manager.save()
      logging.info('Checkpoint saved at iter %d', global_step_value)