def _get_dataset_and_samples(config, train_crop_size, dataset_name,
                             dataset_dir, first_frame_finetuning,
                             three_frame_dataset):
  """Creates dataset object and samples dict of tensor.

  Args:
    config: A DeploymentConfig.
    train_crop_size: Integer, the crop size used for training.
    dataset_name: String, the name of the dataset.
    dataset_dir: String, the directory of the dataset.
    first_frame_finetuning: Boolean, whether the used dataset is a dataset for
      first frame fine-tuning.
    three_frame_dataset: Boolean, whether the dataset has exactly three frames
      per video of which the first is to be used as reference and the two
      others are consecutive frames to be used as query frames.

  Returns:
    dataset: An instance of slim Dataset.
    samples: A dictionary of tensors for semantic segmentation.

  Raises:
    AssertionError: If FLAGS.train_batch_size is not divisible by
      config.num_clones.
  """
  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisble by number of clones (GPUs).')

  # Floor division keeps the per-clone batch size an integer. True division
  # would produce a float under Python 3 (or `from __future__ import
  # division`) even though the assert above guarantees exact divisibility.
  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # First-frame fine-tuning always reads the 'val' split; otherwise the split
  # is taken from the flags.
  if first_frame_finetuning:
    train_split = 'val'
  else:
    train_split = FLAGS.train_split

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      dataset_name,
      train_split,
      dataset_dir=dataset_dir,
      data_type=data_type)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', train_split)

  # NOTE(review): `dataset_split` below receives FLAGS.train_split even when
  # the dataset itself was loaded from 'val' for first-frame fine-tuning --
  # confirm that this is intended.
  samples = video_input_generator.get(
      dataset,
      FLAGS.train_num_frames_per_video,
      train_crop_size,
      clone_batch_size,
      num_readers=FLAGS.num_readers,
      num_threads=FLAGS.batch_num_threads,
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      min_scale_factor=FLAGS.min_scale_factor,
      max_scale_factor=FLAGS.max_scale_factor,
      scale_factor_step_size=FLAGS.scale_factor_step_size,
      dataset_split=FLAGS.train_split,
      is_training=True,
      model_variant=FLAGS.model_variant,
      batch_capacity_factor=FLAGS.batch_capacity_factor,
      decoder_output_stride=common.parse_decoder_output_stride(),
      first_frame_finetuning=first_frame_finetuning,
      sample_only_first_frame_for_finetuning=(
          FLAGS.sample_only_first_frame_for_finetuning),
      # Softmax feedback also forces adjacent and consistent query frames.
      sample_adjacent_and_consistent_query_frames=(
          FLAGS.sample_adjacent_and_consistent_query_frames or
          FLAGS.use_softmax_feedback),
      remap_labels_to_reference_frame=True,
      three_frame_dataset=three_frame_dataset,
      add_prev_frame_label=not FLAGS.also_attend_to_previous_frame)
  return dataset, samples
def main(unused_argv):
  """Runs the visualization loop over checkpoints in FLAGS.checkpoint_dir.

  Builds the input pipeline and the prediction graph once, then repeatedly
  waits for a new checkpoint, restores it, runs the predictions for
  `num_vis_examples` sequences, hands every result to `_process_seq_data`,
  logs the mean of the returned IoU values and writes them to a text file
  below `<FLAGS.vis_logdir>/results`. The loop ends after one pass when
  FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit is set.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.vis_batch_size is not 1, or if FLAGS.save_embeddings
      is set although no embeddings tensor has been built.
  """
  if FLAGS.vis_batch_size != 1:
    raise ValueError('Only batch size 1 is supported for now')

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      FLAGS.dataset,
      FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      data_type=data_type)

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
  tf.gfile.MakeDirs(segmentation_dir)
  embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
  tf.gfile.MakeDirs(embeddings_dir)

  # A negative flag value means "visualize every video of the dataset".
  num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
                      else FLAGS.num_vis_examples)
  if FLAGS.first_frame_finetuning:
    num_vis_examples = 1

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
  g = tf.Graph()
  with g.as_default():
    # Without setting device to CPU we run out of memory.
    with tf.device('cpu:0'):
      samples = video_input_generator.get(
          dataset,
          None,
          None,
          FLAGS.vis_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          dataset_split=FLAGS.vis_split,
          is_training=False,
          model_variant=FLAGS.model_variant,
          preprocess_image_and_label=False,
          remap_labels_to_reference_frame=False)
      samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
      samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
      first_frame_img = samples[common.IMAGE][0]
      reference_labels = samples[common.LABEL][0, tf.newaxis]
      gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
      seq_name = samples[common.VIDEO_ID][0]

    model_options = common.VideoModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=None,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    all_embeddings = None
    predicted_labels = create_predictions_fast(
        samples, reference_labels, first_frame_img, model_options)
    # If you need more options like saving embeddings, replace the call to
    # create_predictions_fast with create_predictions.

    # Fail early with a clear message: fetching `None` in sess.run would
    # otherwise only fail after waiting for a checkpoint.
    if FLAGS.save_embeddings and all_embeddings is None:
      raise ValueError(
          'FLAGS.save_embeddings is set but no embeddings tensor was built; '
          'replace the call to create_predictions_fast with '
          'create_predictions.')

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    sv = tf.train.Supervisor(graph=g,
                             logdir=FLAGS.vis_logdir,
                             init_op=tf.global_variables_initializer(),
                             summary_op=None,
                             summary_writer=None,
                             global_step=None,
                             saver=saver)
    num_batches = int(
        math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # The fetch list does not change between batches, so build it once.
    ops = [predicted_labels, gt_labels, seq_name]
    if FLAGS.save_embeddings:
      ops.append(all_embeddings)

    # Infinite loop to visualize the results when new checkpoint is created.
    while True:
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info('Starting visualization at %s',
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)

      all_ious = []
      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        for batch in range(num_batches):
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          res = sess.run(ops)
          tf.logging.info('Forwarding done')
          pred_labels_val, gt_labels_val, seq_name_val = res[:3]
          all_embeddings_val = res[3] if FLAGS.save_embeddings else None
          seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
                                       seq_name_val, pred_labels_val,
                                       gt_labels_val, all_embeddings_val)
          all_ious.append(seq_ious)

      all_ious = np.concatenate(all_ious, axis=0)
      tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
      tf.logging.info('Finished visualization at %s',
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))

      result_dir = os.path.join(FLAGS.vis_logdir, 'results')
      tf.gfile.MakeDirs(result_dir)
      # sess.run returns string tensors as bytes under Python 3; convert
      # before building the file name, otherwise str + bytes raises TypeError.
      # NOTE: the file is named after the last sequence processed in this
      # pass.
      result_name = tf.compat.as_str(seq_name_val) + '.txt'
      with tf.gfile.GFile(os.path.join(result_dir, result_name), 'w') as f:
        f.write(str(all_ious))

      if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
        break
      # Sleep only for what remains of the evaluation interval.
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
def main(unused_argv):
  """Runs the visualization loop over checkpoints in FLAGS.checkpoint_dir.

  Builds the input pipeline and the prediction graph once, then repeatedly
  waits for a new checkpoint, restores it, runs the predictions for
  `num_vis_examples` sequences, hands every result to `_process_seq_data`,
  logs the mean of the returned IoU values and writes them to a text file
  below `<FLAGS.vis_logdir>/results`. The loop ends after one pass when
  FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit is set.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.vis_batch_size is not 1, or if FLAGS.save_embeddings
      is set although no embeddings tensor has been built.
  """
  if FLAGS.vis_batch_size != 1:
    raise ValueError('Only batch size 1 is supported for now')

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      FLAGS.dataset,
      FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      data_type=data_type)

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
  tf.gfile.MakeDirs(segmentation_dir)
  embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
  tf.gfile.MakeDirs(embeddings_dir)

  # A negative flag value means "visualize every video of the dataset".
  num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0)
                      else FLAGS.num_vis_examples)
  if FLAGS.first_frame_finetuning:
    num_vis_examples = 1

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
  g = tf.Graph()
  with g.as_default():
    # Without setting device to CPU we run out of memory.
    with tf.device('cpu:0'):
      samples = video_input_generator.get(
          dataset,
          None,
          None,
          FLAGS.vis_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          dataset_split=FLAGS.vis_split,
          is_training=False,
          model_variant=FLAGS.model_variant,
          preprocess_image_and_label=False,
          remap_labels_to_reference_frame=False)
      samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
      samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
      first_frame_img = samples[common.IMAGE][0]
      reference_labels = samples[common.LABEL][0, tf.newaxis]
      gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
      seq_name = samples[common.VIDEO_ID][0]

    model_options = common.VideoModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=None,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    all_embeddings = None
    predicted_labels = create_predictions_fast(
        samples, reference_labels, first_frame_img, model_options)
    # If you need more options like saving embeddings, replace the call to
    # create_predictions_fast with create_predictions.

    # Fail early with a clear message: fetching `None` in sess.run would
    # otherwise only fail after waiting for a checkpoint.
    if FLAGS.save_embeddings and all_embeddings is None:
      raise ValueError(
          'FLAGS.save_embeddings is set but no embeddings tensor was built; '
          'replace the call to create_predictions_fast with '
          'create_predictions.')

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    sv = tf.train.Supervisor(graph=g,
                             logdir=FLAGS.vis_logdir,
                             init_op=tf.global_variables_initializer(),
                             summary_op=None,
                             summary_writer=None,
                             global_step=None,
                             saver=saver)
    num_batches = int(
        math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # The fetch list does not change between batches, so build it once.
    ops = [predicted_labels, gt_labels, seq_name]
    if FLAGS.save_embeddings:
      ops.append(all_embeddings)

    # Infinite loop to visualize the results when new checkpoint is created.
    while True:
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info('Starting visualization at %s',
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)

      all_ious = []
      with sv.managed_session(FLAGS.master,
                              start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        for batch in range(num_batches):
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          res = sess.run(ops)
          tf.logging.info('Forwarding done')
          pred_labels_val, gt_labels_val, seq_name_val = res[:3]
          all_embeddings_val = res[3] if FLAGS.save_embeddings else None
          seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
                                       seq_name_val, pred_labels_val,
                                       gt_labels_val, all_embeddings_val)
          all_ious.append(seq_ious)

      all_ious = np.concatenate(all_ious, axis=0)
      tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
      tf.logging.info('Finished visualization at %s',
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))

      result_dir = os.path.join(FLAGS.vis_logdir, 'results')
      tf.gfile.MakeDirs(result_dir)
      # sess.run returns string tensors as bytes under Python 3; convert
      # before building the file name, otherwise str + bytes raises TypeError.
      # NOTE: the file is named after the last sequence processed in this
      # pass.
      result_name = tf.compat.as_str(seq_name_val) + '.txt'
      with tf.gfile.GFile(os.path.join(result_dir, result_name), 'w') as f:
        f.write(str(all_ious))

      if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
        break
      # Sleep only for what remains of the evaluation interval.
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)