def _get_dataset_and_samples(config, train_crop_size, dataset_name,
                             dataset_dir, first_frame_finetuning,
                             three_frame_dataset):
    """Creates dataset object and samples dict of tensor.

  Args:
    config: A DeploymentConfig.
    train_crop_size: Integer, the crop size used for training.
    dataset_name: String, the name of the dataset.
    dataset_dir: String, the directory of the dataset.
    first_frame_finetuning: Boolean, whether the used dataset is a dataset
      for first frame fine-tuning.
    three_frame_dataset: Boolean, whether the dataset has exactly three frames
      per video of which the first is to be used as reference and the two
      others are consecutive frames to be used as query frames.

  Returns:
    dataset: An instance of slim Dataset.
    samples: A dictionary of tensors for semantic segmentation.
  """

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    # Integer division keeps the per-clone batch size an integer.
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    if first_frame_finetuning:
        train_split = 'val'
    else:
        train_split = FLAGS.train_split

    data_type = 'tf_sequence_example'
    # Get dataset-dependent information.
    dataset = video_dataset.get_dataset(dataset_name,
                                        train_split,
                                        dataset_dir=dataset_dir,
                                        data_type=data_type)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', train_split)

    samples = video_input_generator.get(
        dataset,
        FLAGS.train_num_frames_per_video,
        train_crop_size,
        clone_batch_size,
        num_readers=FLAGS.num_readers,
        num_threads=FLAGS.batch_num_threads,
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        min_scale_factor=FLAGS.min_scale_factor,
        max_scale_factor=FLAGS.max_scale_factor,
        scale_factor_step_size=FLAGS.scale_factor_step_size,
        dataset_split=FLAGS.train_split,
        is_training=True,
        model_variant=FLAGS.model_variant,
        batch_capacity_factor=FLAGS.batch_capacity_factor,
        decoder_output_stride=common.parse_decoder_output_stride(),
        first_frame_finetuning=first_frame_finetuning,
        sample_only_first_frame_for_finetuning=(
            FLAGS.sample_only_first_frame_for_finetuning),
        sample_adjacent_and_consistent_query_frames=(
            FLAGS.sample_adjacent_and_consistent_query_frames or
            FLAGS.use_softmax_feedback),
        remap_labels_to_reference_frame=True,
        three_frame_dataset=three_frame_dataset,
        add_prev_frame_label=not FLAGS.also_attend_to_previous_frame)
    return dataset, samples
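

# The following is an illustrative, self-contained sketch (not part of the
# original source) of the batch-splitting logic at the top of
# _get_dataset_and_samples: the global training batch is split evenly across
# the clones (GPUs) described by the DeploymentConfig, and the split must be
# exact. `num_clones` here is a plain-integer stand-in for `config.num_clones`.
def _split_batch_across_clones(train_batch_size, num_clones):
    """Returns the per-clone (per-GPU) batch size."""
    if train_batch_size % num_clones != 0:
        raise ValueError(
            'Training batch size %d is not divisible by %d clones (GPUs).' %
            (train_batch_size, num_clones))
    return train_batch_size // num_clones


# Example: a global batch of 12 split across 4 GPUs gives 3 samples per clone.
assert _split_batch_across_clones(12, 4) == 3

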
def main(unused_argv):
  if FLAGS.vis_batch_size != 1:
    raise ValueError('Only batch size 1 is supported for now')

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      FLAGS.dataset,
      FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      data_type=data_type)

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
  tf.gfile.MakeDirs(segmentation_dir)
  embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
  tf.gfile.MakeDirs(embeddings_dir)
  num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
                      else FLAGS.num_vis_examples)
  if FLAGS.first_frame_finetuning:
    num_vis_examples = 1

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
  g = tf.Graph()
  with g.as_default():
    # Without setting device to CPU we run out of memory.
    with tf.device('cpu:0'):
      samples = video_input_generator.get(
          dataset,
          None,
          None,
          FLAGS.vis_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          dataset_split=FLAGS.vis_split,
          is_training=False,
          model_variant=FLAGS.model_variant,
          preprocess_image_and_label=False,
          remap_labels_to_reference_frame=False)
      samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
      samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
      first_frame_img = samples[common.IMAGE][0]
      reference_labels = samples[common.LABEL][0, tf.newaxis]
      gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
      seq_name = samples[common.VIDEO_ID][0]

    model_options = common.VideoModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=None,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # create_predictions_fast does not return embeddings, so all_embeddings
    # stays None here. If you need more options, such as saving embeddings,
    # replace the call to create_predictions_fast with create_predictions.
    all_embeddings = None
    predicted_labels = create_predictions_fast(
        samples, reference_labels, first_frame_img, model_options)

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    sv = tf.train.Supervisor(graph=g,
                             logdir=FLAGS.vis_logdir,
                             init_op=tf.global_variables_initializer(),
                             summary_op=None,
                             summary_writer=None,
                             global_step=None,
                             saver=saver)
    num_batches = int(
        math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # Infinite loop to visualize the results whenever a new checkpoint is
    # created.
    while True:
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info(
          'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)

      all_ious = []
      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        for batch in range(num_batches):
          ops = [predicted_labels, gt_labels, seq_name]
          if FLAGS.save_embeddings:
            ops.append(all_embeddings)
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          res = sess.run(ops)
          tf.logging.info('Forwarding done')
          pred_labels_val, gt_labels_val, seq_name_val = res[:3]
          if FLAGS.save_embeddings:
            all_embeddings_val = res[3]
          else:
            all_embeddings_val = None
          seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
                                       seq_name_val, pred_labels_val,
                                       gt_labels_val, all_embeddings_val)
          all_ious.append(seq_ious)
      all_ious = np.concatenate(all_ious, axis=0)
      tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
      tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      result_dir = FLAGS.vis_logdir + '/results/'
      tf.gfile.MakeDirs(result_dir)
      with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
        f.write(str(all_ious))
      if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
        break
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
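

# The following is an illustrative, self-contained sketch (not part of the
# original source) of the checkpoint-polling cadence at the bottom of main():
# each pass records its start time and, once the work is done, sleeps only for
# the remainder of eval_interval_secs, so consecutive passes start at least
# that many seconds apart. `visualize_fn` and `eval_once` are hypothetical
# stand-ins for the body of the while-loop and FLAGS.eval_once_and_quit.
import time


def _run_visualization_loop(visualize_fn, eval_interval_secs, eval_once=False):
  while True:
    start = time.time()
    visualize_fn()
    if eval_once:
      break
    time_to_next_eval = start + eval_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


# Example: a single dummy pass that returns immediately.
_run_visualization_loop(lambda: None, eval_interval_secs=0, eval_once=True)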