def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
                   ignore_label):
  """Builds a clone of DeepLab.

  Args:
    inputs_queue_or_samples: A prefetch queue for images and labels, or
      directly a dict of the samples.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
      dictionary of multi-scale logits names to logits. For each output_type,
      the dictionary has keys which correspond to the scales and values which
      correspond to the logits. For example, if `scales` equals [1.0, 1.5],
      then the keys would include 'merged_logits', 'logits_1.00' and
      'logits_1.50'.

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      or triplet.
  """
  if hasattr(inputs_queue_or_samples, 'dequeue'):
    samples = inputs_queue_or_samples.dequeue()
  else:
    samples = inputs_queue_or_samples
  train_crop_size = (None if 0 in FLAGS.train_crop_size else
                     FLAGS.train_crop_size)

  model_options = common.VideoModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=train_crop_size,
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  if model_options.classification_loss == 'softmax_with_attention':
    clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

    # Create summaries of ground truth labels.
    for n in range(clone_batch_size):
      tf.summary.image(
          'gt_label_%d' % n,
          tf.cast(samples[common.LABEL]
                  [n * FLAGS.train_num_frames_per_video:
                   (n + 1) * FLAGS.train_num_frames_per_video],
                  tf.uint8) * 32,
          max_outputs=FLAGS.train_num_frames_per_video)

    if common.PRECEDING_FRAME_LABEL in samples:
      preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
      init_softmax = []
      for n in range(clone_batch_size):
        init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
            preceding_frame_label[n, tf.newaxis],
            samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                  tf.newaxis],
            common.parse_decoder_output_stride(),
            reduce_labels=True)
        init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
        init_softmax.append(init_softmax_n)
        tf.summary.image(
            'preceding_frame_label',
            tf.cast(preceding_frame_label[n, tf.newaxis], tf.uint8) * 32)
    else:
      init_softmax = None

    outputs_to_scales_to_logits = (
        model.multi_scale_logits_with_nearest_neighbor_matching(
            samples[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            reference_labels=samples[common.LABEL],
            clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
            num_frames_per_video=FLAGS.train_num_frames_per_video,
            embedding_dimension=FLAGS.embedding_dimension,
            max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
            k_nearest_neighbors=FLAGS.k_nearest_neighbors,
            use_softmax_feedback=FLAGS.use_softmax_feedback,
            initial_softmax_feedback=init_softmax,
            embedding_seg_feature_dimension=(
                FLAGS.embedding_seg_feature_dimension),
            embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
            embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
            embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
            normalize_nearest_neighbor_distances=(
                FLAGS.normalize_nearest_neighbor_distances),
            also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
            damage_initial_previous_frame_mask=(
                FLAGS.damage_initial_previous_frame_mask),
            use_local_previous_frame_attention=(
                FLAGS.use_local_previous_frame_attention),
            previous_frame_attention_window_size=(
                FLAGS.previous_frame_attention_window_size),
            use_first_frame_matching=FLAGS.use_first_frame_matching))
  else:
    outputs_to_scales_to_logits = model.multi_scale_logits_v2(
        samples[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

  if model_options.classification_loss == 'softmax':
    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          samples[common.LABEL],
          num_classes,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output)
  elif model_options.classification_loss == 'triplet':
    for output, _ in six.iteritems(outputs_to_num_classes):
      train_utils.add_triplet_loss_for_each_scale(
          FLAGS.train_batch_size // FLAGS.num_clones,
          FLAGS.train_num_frames_per_video,
          FLAGS.embedding_dimension,
          outputs_to_scales_to_logits[output],
          samples[common.LABEL],
          scope=output)
  elif model_options.classification_loss == 'softmax_with_attention':
    labels = samples[common.LABEL]
    batch_size = FLAGS.train_batch_size // FLAGS.num_clones
    num_frames_per_video = FLAGS.train_num_frames_per_video
    h, w = train_utils.resolve_shape(labels)[1:3]
    labels = tf.reshape(
        labels, tf.stack([batch_size, num_frames_per_video, h, w, 1]))
    # Strip the reference labels off.
    if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
      n_ref_frames = 2
    else:
      n_ref_frames = 1
    labels = labels[:, n_ref_frames:]
    # Merge batch and time dimensions.
    labels = tf.reshape(
        labels,
        tf.stack(
            [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))

    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          labels,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output,
          top_k_percent_pixels=FLAGS.top_k_percent_pixels,
          hard_example_mining_step=FLAGS.hard_example_mining_step)
  else:
    raise ValueError('Only support softmax, softmax_with_attention or '
                     'triplet for classification_loss.')

  return outputs_to_scales_to_logits
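

# Illustrative sketch (not part of the original training code): the nested
# structure returned by _build_deeplab, assuming a single 'semantic' output
# and image_pyramid=[1.0, 1.5]. The tensor shapes shown are hypothetical
# placeholders.
#
#   outputs_to_scales_to_logits = {
#       'semantic': {
#           'logits_1.00': <float32 Tensor of per-scale logits>,
#           'logits_1.50': <float32 Tensor of per-scale logits>,
#           'merged_logits': <float32 Tensor of merged multi-scale logits>,
#       }
#   }
#
# Worked shape example for the 'softmax_with_attention' label handling above,
# assuming a clone batch size of 2, train_num_frames_per_video=3 and
# use_softmax_feedback=True (so n_ref_frames=2): labels of shape [6, h, w, 1]
# are reshaped to [2, 3, h, w, 1], the two reference frames per video are
# stripped to give [2, 1, h, w, 1], and batch and time are merged back to
# [2, h, w, 1] before the dynamic softmax cross-entropy loss is added.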
def main(unused_argv):
  if FLAGS.vis_batch_size != 1:
    raise ValueError('Only batch size 1 is supported for now')

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      FLAGS.dataset,
      FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      data_type=data_type)

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  segmentation_dir = os.path.join(FLAGS.vis_logdir, _SEGMENTATION_SAVE_FOLDER)
  tf.gfile.MakeDirs(segmentation_dir)
  embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
  tf.gfile.MakeDirs(embeddings_dir)
  num_vis_examples = (dataset.num_videos if FLAGS.num_vis_examples < 0
                      else FLAGS.num_vis_examples)
  if FLAGS.first_frame_finetuning:
    num_vis_examples = 1

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
  g = tf.Graph()
  with g.as_default():
    # Without setting device to CPU we run out of memory.
    with tf.device('cpu:0'):
      samples = video_input_generator.get(
          dataset,
          None,
          None,
          FLAGS.vis_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          dataset_split=FLAGS.vis_split,
          is_training=False,
          model_variant=FLAGS.model_variant,
          preprocess_image_and_label=False,
          remap_labels_to_reference_frame=False)
      samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
      samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
      first_frame_img = samples[common.IMAGE][0]
      reference_labels = samples[common.LABEL][0, tf.newaxis]
      gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
      seq_name = samples[common.VIDEO_ID][0]

    model_options = common.VideoModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=None,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    all_embeddings = None
    predicted_labels = create_predictions_fast(
        samples, reference_labels, first_frame_img, model_options)
    # If you need more options like saving embeddings, replace the call to
    # create_predictions_fast with create_predictions.

    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(slim.get_variables_to_restore())
    sv = tf.train.Supervisor(
        graph=g,
        logdir=FLAGS.vis_logdir,
        init_op=tf.global_variables_initializer(),
        summary_op=None,
        summary_writer=None,
        global_step=None,
        saver=saver)
    num_batches = int(
        math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
    last_checkpoint = None

    # Infinite loop to visualize the results when new checkpoint is created.
    while True:
      last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
          FLAGS.checkpoint_dir, last_checkpoint)
      start = time.time()
      tf.logging.info('Starting visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      tf.logging.info('Visualizing with model %s', last_checkpoint)

      all_ious = []
      with sv.managed_session(
          FLAGS.master, start_standard_services=False) as sess:
        sv.start_queue_runners(sess)
        sv.saver.restore(sess, last_checkpoint)

        for batch in range(num_batches):
          ops = [predicted_labels, gt_labels, seq_name]
          if FLAGS.save_embeddings:
            ops.append(all_embeddings)
          tf.logging.info('Visualizing batch %d / %d', batch + 1, num_batches)
          res = sess.run(ops)
          tf.logging.info('Forwarding done')
          pred_labels_val, gt_labels_val, seq_name_val = res[:3]
          if FLAGS.save_embeddings:
            all_embeddings_val = res[3]
          else:
            all_embeddings_val = None
          seq_ious = _process_seq_data(segmentation_dir, embeddings_dir,
                                       seq_name_val, pred_labels_val,
                                       gt_labels_val, all_embeddings_val)
          all_ious.append(seq_ious)

      all_ious = np.concatenate(all_ious, axis=0)
      tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape, all_ious.mean())
      tf.logging.info('Finished visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
      result_dir = FLAGS.vis_logdir + '/results/'
      tf.gfile.MakeDirs(result_dir)
      with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
        f.write(str(all_ious))

      if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
        break
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)
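

# A minimal sketch, assuming this excerpt comes from a standalone TF 1.x
# visualization script: the usual entry point that parses flags and hands
# control to main(). This stub is an assumption and not part of the original
# excerpt.
if __name__ == '__main__':
  tf.app.run()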