def evaluate():
  """Extract embeddings for one dataset split and save them with `np.save`.

  Loads model state from FLAGS.logdir via `restore_ckpt`. It then embeds the
  examples yielded by `create_one_epoch_dataset(FLAGS.dataset, FLAGS.split)`
  and writes the result returned by `get_embeddings_dataset` to
  FLAGS.save_path using `np.save`.
  """
  logdir = FLAGS.logdir
  setup_eval_dir(logdir)
  # Can ignore frame labels if dataset doesn't have per-frame labels.
  CONFIG.DATA.FRAME_LABELS = FLAGS.keep_labels
  # Subsample frames in case videos are long or fps is high to save memory.
  CONFIG.DATA.SAMPLE_ALL_STRIDE = FLAGS.sample_all_stride

  algo = get_algo(CONFIG.TRAINING_ALGO)
  _, optimizer, _ = get_lr_opt_global_step()
  restore_ckpt(logdir=logdir, optimizer=optimizer, **algo.model)

  if FLAGS.defun:
    # Wrap the forward pass and loss in tf.function so they run as graphs.
    algo.call = tf.function(algo.call)
    algo.compute_loss = tf.function(algo.compute_loss)

  iterator = create_one_epoch_dataset(FLAGS.dataset, FLAGS.split, mode='eval',
                                      path_to_tfrecords=FLAGS.path_to_tfrecords)

  # A non-positive FLAGS.max_embs is passed on as None (presumably "no limit"
  # in get_embeddings_dataset -- confirm).
  max_embs = None if FLAGS.max_embs <= 0 else FLAGS.max_embs
  embeddings = get_embeddings_dataset(
      algo.model,
      iterator,
      frames_per_batch=FLAGS.frames_per_batch,
      keep_data=FLAGS.keep_data,
      keep_labels=FLAGS.keep_labels,
      max_embs=max_embs)
  # np.save writes binary data, so open in 'wb'. The `with` block guarantees
  # the file is closed, which flushes buffered bytes. On remote filesystems
  # (e.g. GCS) data is typically only persisted on flush/close.
  with gfile.Open(FLAGS.save_path, 'wb') as f:
    np.save(f, embeddings)
Beispiel #2
0
def evaluate_once(algo, iterator_tasks, embedding_tasks, iterators,
                  summary_writer):
    """Evaluate learnt embeddings on downstream tasks.

    The function proceeds in three steps:

    1. Loads the checkpoint in CONFIG.LOGDIR into `algo.model` via
       `restore_ckpt`.
    2. Evaluates every iterator-based task.
    3. Extracts 'train' and 'val' embeddings for each dataset in
       CONFIG.DATASETS and evaluates every embedding-based task on them.

    Task results are emitted as scalar summaries through `summary_writer`.

    Args:
        algo: Algorithm object; `algo.model` is passed to `restore_ckpt` and
            to the embedding extractor.
        iterator_tasks: Mapping of task name -> task whose `evaluate` is called
            with `iterators`. May be empty or None.
        embedding_tasks: Mapping of task name -> task whose `evaluate` is
            called with the extracted embeddings. May be empty or None.
        iterators: Forwarded to each iterator task's `evaluate`.
        summary_writer: Writer that receives every summary emitted here.

    Side effects:
        Sets the module-level `evaluated_last_ckpt` to True once the restored
        global step has reached CONFIG.TRAIN.MAX_ITERS.
    """
    global evaluated_last_ckpt

    # Sets up model for evaluation by restoring the saved state.
    _, optimizer, global_step = get_lr_opt_global_step()
    restore_ckpt(logdir=CONFIG.LOGDIR, optimizer=optimizer, **algo.model)

    # Use `>=` rather than `==`. A checkpoint whose step overshoots MAX_ITERS
    # (e.g. MAX_ITERS lowered after training) must still count as the last
    # one; otherwise the flag would never be set.
    if global_step.numpy() >= CONFIG.TRAIN.MAX_ITERS:
        evaluated_last_ckpt = True

    metrics = {}
    if iterator_tasks:
        with summary_writer.as_default():
            with tf.summary.record_if(True):
                for task_name, task in iterator_tasks.items():
                    metrics[task_name] = task.evaluate(algo,
                                                       global_step,
                                                       iterators=iterators)

    # A non-positive FLAGS.max_embs is passed on as None (presumably "no
    # limit" in get_embeddings_dataset -- confirm).
    max_embs = None if FLAGS.max_embs <= 0 else FLAGS.max_embs
    if embedding_tasks:
        frames_per_batch = CONFIG.EVAL.FRAMES_PER_BATCH
        for dataset_name in CONFIG.DATASETS:
            # Values are evaluated in order: train embeddings, then val.
            dataset = {
                'name': dataset_name,
                'train_dataset': _extract_split_embeddings(
                    algo, dataset_name, 'train', frames_per_batch, max_embs),
                'val_dataset': _extract_split_embeddings(
                    algo, dataset_name, 'val', frames_per_batch, max_embs),
            }

            with summary_writer.as_default():
                with tf.summary.record_if(True):
                    for task_name, task in embedding_tasks.items():
                        task_metrics = metrics.setdefault(task_name, {})
                        task_metrics[dataset_name] = task.evaluate(
                            algo, global_step, embeddings_dataset=dataset)

    # Add all metrics in a separate tag so that analysis is easier.
    # `embedding_tasks` may be None (see the guard above), so fall back to {}.
    with summary_writer.as_default():
        with tf.summary.record_if(True):
            for task_name in embedding_tasks or {}:
                per_dataset = metrics[task_name]
                for dataset_name in CONFIG.DATASETS:
                    tf.summary.scalar(
                        'metrics/%s_%s' % (dataset_name, task_name),
                        per_dataset[dataset_name],
                        step=global_step)
                # Unweighted mean of this task's metric over all datasets.
                avg_metric = sum(per_dataset.values()) / len(CONFIG.DATASETS)
                tf.summary.scalar('metrics/all_%s' % task_name,
                                  avg_metric,
                                  step=global_step)


def _extract_split_embeddings(algo, dataset_name, split, frames_per_batch,
                              max_embs):
    """Embed one `split` of `dataset_name` with `algo.model` and return it."""
    iterator = create_one_epoch_dataset(
        dataset_name,
        split,
        mode='eval',
        path_to_tfrecords=CONFIG.PATH_TO_TFRECORDS)
    return get_embeddings_dataset(
        algo.model,
        iterator,
        frames_per_batch=frames_per_batch,
        max_embs=max_embs)
Beispiel #3
0
def train():
  """Train the model configured by CONFIG.TRAINING_ALGO.

  Training runs under a `tf.distribute.MirroredStrategy` (all visible GPUs by
  default). State is loaded from FLAGS.logdir via `restore_ckpt`. The loop then
  runs from the restored global step until it reaches CONFIG.TRAIN.MAX_ITERS.

  Logging and checkpointing:
  - Loss, learning rate and per-iteration time are written as summaries under
    `<logdir>/train_logs` on iterations where the step is a multiple of
    CONFIG.LOGGING.REPORT_INTERVAL.
  - A checkpoint is saved every CONFIG.CHECKPOINT.SAVE_INTERVAL iterations.
  - A final checkpoint is saved whenever the loop exits: normal completion,
    Ctrl-C, or an exception (which is re-raised after saving).

  Downstream-task evaluation is not run from this function.
  """
  CONFIG.LOGDIR = FLAGS.logdir
  logdir = CONFIG.LOGDIR
  setup_train_dir(logdir)

  # Common code for multigpu and single gpu. Set devices here if you don't
  # want to use all the GPUs on the machine. Default is to use all GPUs.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    algo = get_algo(CONFIG.TRAINING_ALGO)

    # Setup summary writer. Events are buffered for up to flush_millis.
    summary_writer = tf.summary.create_file_writer(
        os.path.join(logdir, 'train_logs'), flush_millis=10000)

    learning_rate, optimizer, global_step = get_lr_opt_global_step()
    ckpt_manager, _, _ = restore_ckpt(
        logdir=logdir, optimizer=optimizer, **algo.model)

    global_step_value = global_step.numpy()

    # Remember in Eager mode learning rate variable needs to be updated
    # manually. Calling lr_fn each iteration to get current learning rate.
    lr_fn = get_lr_fn(CONFIG.OPTIMIZER)

    # Setup Dataset Iterators from train and val datasets.
    # The global batch is split evenly across replicas.
    batch_size_per_replica = CONFIG.TRAIN.BATCH_SIZE
    total_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    train_ds = create_dataset('train', mode='train',
                              batch_size=total_batch_size,
                              return_iterator=False)
    train_iterator = strategy.make_dataset_iterator(train_ds)

    def train_step(data):
      """Run one optimization step on a single replica's batch."""
      steps = data['chosen_steps']
      seq_lens = data['seq_lens']
      loss = algo.train_one_iter(data, steps, seq_lens, global_step, optimizer)
      return loss

    # This reduction only affects reporting, not the gradients.
    # pylint: disable=g-long-lambda
    dist_train = lambda it: strategy.reduce(
        tf.distribute.ReduceOp.SUM, strategy.experimental_run(train_step, it),
        axis=None)
    # pylint: enable=g-long-lambda
    if FLAGS.defun:
      dist_train = tf.function(dist_train)

    stopwatch = Stopwatch()

    try:
      while global_step_value < CONFIG.TRAIN.MAX_ITERS:
        with summary_writer.as_default():
          # Only emit summaries on report iterations.
          with tf.summary.record_if(
              global_step_value % CONFIG.LOGGING.REPORT_INTERVAL == 0):

            loss = dist_train(train_iterator)

            # Update learning rate based in lr_fn.
            learning_rate.assign(lr_fn(learning_rate, global_step))

            tf.summary.scalar('loss', loss, step=global_step)
            tf.summary.scalar('learning_rate', learning_rate, step=global_step)

            # Save checkpoint.
            if global_step_value % CONFIG.CHECKPOINT.SAVE_INTERVAL == 0:
              ckpt_manager.save()
              logging.info('Checkpoint saved at iter %d.', global_step_value)

            # Update global step (incremented inside the training step).
            global_step_value = global_step.numpy()

            time_per_iter = stopwatch.elapsed()

            tf.summary.scalar(
                'timing/time_per_iter', time_per_iter, step=global_step)

            # Lazy %-args; %s keeps the same text the former str.format gave.
            logging.info('Iter[%s/%s], %.1fs/iter, Loss: %.3f',
                         global_step_value, CONFIG.TRAIN.MAX_ITERS,
                         time_per_iter, loss.numpy())
            # Reset stopwatch after iter is complete.
            stopwatch.reset()

    except KeyboardInterrupt:
      logging.info('Caught keyboard interrupt. Saving model before quitting.')

    finally:
      # Save the final checkpoint.
      ckpt_manager.save()
      logging.info('Checkpoint saved at iter %d', global_step_value)
      # The writer buffers events for up to flush_millis; flush so the last
      # summaries are not lost when training stops.
      summary_writer.flush()