def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
                   ignore_label):
    """Builds a clone of DeepLab.

    Args:
      inputs_queue_or_samples: A prefetch queue for images and labels, or
        directly a dict of the samples.
      outputs_to_num_classes: A map from output type to the number of classes.
        For example, for the task of semantic segmentation with 21 semantic
        classes, we would have outputs_to_num_classes['semantic'] = 21.
      ignore_label: The label value to ignore when computing the loss.

    Returns:
      A map of maps from output_type (e.g., semantic prediction) to a
        dictionary of multi-scale logits names to logits. For each output_type,
        the dictionary has keys which correspond to the scales and values which
        correspond to the logits. For example, if `scales` equals [1.0, 1.5],
        then the keys would include 'merged_logits', 'logits_1.00' and
        'logits_1.50'.

    Raises:
      ValueError: If classification_loss is not softmax, softmax_with_attention,
        or triplet.
    """
    if hasattr(inputs_queue_or_samples, 'dequeue'):
        samples = inputs_queue_or_samples.dequeue()
    else:
        samples = inputs_queue_or_samples
    train_crop_size = (None if 0 in FLAGS.train_crop_size else
                       FLAGS.train_crop_size)

    model_options = common.VideoModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=train_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if model_options.classification_loss == 'softmax_with_attention':
        clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

        # Create summaries of ground truth labels.
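        # The slicing below assumes that samples[common.LABEL] stacks the
        # frames of all videos in the clone batch along the first axis, so
        # video n occupies the train_num_frames_per_video consecutive entries
        # starting at n * train_num_frames_per_video.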
        for n in range(clone_batch_size):
            tf.summary.image(
                'gt_label_%d' % n,
                tf.cast(
                    samples[common.LABEL]
                    [n * FLAGS.train_num_frames_per_video:(n + 1) *
                     FLAGS.train_num_frames_per_video], tf.uint8) * 32,
                max_outputs=FLAGS.train_num_frames_per_video)

        if common.PRECEDING_FRAME_LABEL in samples:
            preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
            init_softmax = []
            for n in range(clone_batch_size):
                init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
                    preceding_frame_label[n, tf.newaxis],
                    samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                          tf.newaxis],
                    common.parse_decoder_output_stride(),
                    reduce_labels=True)
                init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
                init_softmax.append(init_softmax_n)
                tf.summary.image(
                    'preceding_frame_label',
                    tf.cast(preceding_frame_label[n, tf.newaxis], tf.uint8) *
                    32)
        else:
            init_softmax = None

        outputs_to_scales_to_logits = (
            model.multi_scale_logits_with_nearest_neighbor_matching(
                samples[common.IMAGE],
                model_options=model_options,
                image_pyramid=FLAGS.image_pyramid,
                weight_decay=FLAGS.weight_decay,
                is_training=True,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                reference_labels=samples[common.LABEL],
                clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
                num_frames_per_video=FLAGS.train_num_frames_per_video,
                embedding_dimension=FLAGS.embedding_dimension,
                max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
                k_nearest_neighbors=FLAGS.k_nearest_neighbors,
                use_softmax_feedback=FLAGS.use_softmax_feedback,
                initial_softmax_feedback=init_softmax,
                embedding_seg_feature_dimension=(
                    FLAGS.embedding_seg_feature_dimension),
                embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
                embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
                embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
                normalize_nearest_neighbor_distances=(
                    FLAGS.normalize_nearest_neighbor_distances),
                also_attend_to_previous_frame=(
                    FLAGS.also_attend_to_previous_frame),
                damage_initial_previous_frame_mask=(
                    FLAGS.damage_initial_previous_frame_mask),
                use_local_previous_frame_attention=(
                    FLAGS.use_local_previous_frame_attention),
                previous_frame_attention_window_size=(
                    FLAGS.previous_frame_attention_window_size),
                use_first_frame_matching=FLAGS.use_first_frame_matching))
    else:
        outputs_to_scales_to_logits = model.multi_scale_logits_v2(
            samples[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

    if model_options.classification_loss == 'softmax':
        for output, num_classes in six.iteritems(outputs_to_num_classes):
            train_utils.add_softmax_cross_entropy_loss_for_each_scale(
                outputs_to_scales_to_logits[output],
                samples[common.LABEL],
                num_classes,
                ignore_label,
                loss_weight=1.0,
                upsample_logits=FLAGS.upsample_logits,
                scope=output)
    elif model_options.classification_loss == 'triplet':
        for output, _ in six.iteritems(outputs_to_num_classes):
            train_utils.add_triplet_loss_for_each_scale(
                FLAGS.train_batch_size // FLAGS.num_clones,
                FLAGS.train_num_frames_per_video,
                FLAGS.embedding_dimension,
                outputs_to_scales_to_logits[output],
                samples[common.LABEL],
                scope=output)
    elif model_options.classification_loss == 'softmax_with_attention':
        labels = samples[common.LABEL]
        batch_size = FLAGS.train_batch_size // FLAGS.num_clones
        num_frames_per_video = FLAGS.train_num_frames_per_video
        h, w = train_utils.resolve_shape(labels)[1:3]
        labels = tf.reshape(
            labels, tf.stack([batch_size, num_frames_per_video, h, w, 1]))
        # Strip the reference labels off.
        if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
            n_ref_frames = 2
        else:
            n_ref_frames = 1
        labels = labels[:, n_ref_frames:]
        # Merge batch and time dimensions.
        labels = tf.reshape(
            labels,
            tf.stack(
                [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))
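        # Illustrative shape walkthrough (hypothetical sizes): with
        # batch_size=2, num_frames_per_video=4 and n_ref_frames=2, labels
        # starts as [8, h, w, 1], is reshaped to [2, 4, h, w, 1], stripped of
        # the two reference frames per video to [2, 2, h, w, 1], and finally
        # flattened back to [4, h, w, 1] so that the loss only covers the
        # frames that are actually predicted.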

        for output, num_classes in six.iteritems(outputs_to_num_classes):
            train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
                outputs_to_scales_to_logits[output],
                labels,
                ignore_label,
                loss_weight=1.0,
                upsample_logits=FLAGS.upsample_logits,
                scope=output,
                top_k_percent_pixels=FLAGS.top_k_percent_pixels,
                hard_example_mining_step=FLAGS.hard_example_mining_step)
    else:
        raise ValueError('classification_loss must be softmax, '
                         'softmax_with_attention or triplet.')

    return outputs_to_scales_to_logits
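

# Hypothetical usage sketch: in the full training script, a clone-building
# function like _build_deeplab is typically handed to slim's model_deploy
# together with the per-clone inputs, roughly:
#
#   clones = model_deploy.create_clones(
#       deploy_config, _build_deeplab,
#       args=[samples, {common.OUTPUT_TYPE: dataset.num_classes},
#             dataset.ignore_label])
#
# where deploy_config, samples and dataset are assumed to be set up elsewhere.
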
def main(unused_argv):
    if FLAGS.vis_batch_size != 1:
        raise ValueError('Only batch size 1 is supported for now')

    data_type = 'tf_sequence_example'
    # Get dataset-dependent information.
    dataset = video_dataset.get_dataset(FLAGS.dataset,
                                        FLAGS.vis_split,
                                        dataset_dir=FLAGS.dataset_dir,
                                        data_type=data_type)

    # Prepare for visualization.
    tf.gfile.MakeDirs(FLAGS.vis_logdir)
    segmentation_dir = os.path.join(FLAGS.vis_logdir,
                                    _SEGMENTATION_SAVE_FOLDER)
    tf.gfile.MakeDirs(segmentation_dir)
    embeddings_dir = os.path.join(FLAGS.vis_logdir, _EMBEDDINGS_SAVE_FOLDER)
    tf.gfile.MakeDirs(embeddings_dir)
    num_vis_examples = (dataset.num_videos if (FLAGS.num_vis_examples < 0) else
                        FLAGS.num_vis_examples)
    if FLAGS.first_frame_finetuning:
        num_vis_examples = 1

    tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
    g = tf.Graph()
    with g.as_default():
        # Without setting device to CPU we run out of memory.
        with tf.device('cpu:0'):
            samples = video_input_generator.get(
                dataset,
                None,
                None,
                FLAGS.vis_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                dataset_split=FLAGS.vis_split,
                is_training=False,
                model_variant=FLAGS.model_variant,
                preprocess_image_and_label=False,
                remap_labels_to_reference_frame=False)
            samples[common.IMAGE] = tf.cast(samples[common.IMAGE], tf.float32)
            samples[common.LABEL] = tf.cast(samples[common.LABEL], tf.int32)
            first_frame_img = samples[common.IMAGE][0]
            reference_labels = samples[common.LABEL][0, tf.newaxis]
            gt_labels = tf.squeeze(samples[common.LABEL], axis=-1)
            seq_name = samples[common.VIDEO_ID][0]

        model_options = common.VideoModelOptions(
            outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
            crop_size=None,
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        all_embeddings = None
        predicted_labels = create_predictions_fast(samples, reference_labels,
                                                   first_frame_img,
                                                   model_options)
        # If you need more options like saving embeddings, replace the call to
        # create_predictions_fast with create_predictions.
        if FLAGS.save_embeddings and all_embeddings is None:
            raise ValueError('save_embeddings requires create_predictions '
                             'instead of create_predictions_fast.')

        tf.train.get_or_create_global_step()
        saver = tf.train.Saver(slim.get_variables_to_restore())
        sv = tf.train.Supervisor(graph=g,
                                 logdir=FLAGS.vis_logdir,
                                 init_op=tf.global_variables_initializer(),
                                 summary_op=None,
                                 summary_writer=None,
                                 global_step=None,
                                 saver=saver)
        num_batches = int(
            math.ceil(num_vis_examples / float(FLAGS.vis_batch_size)))
        last_checkpoint = None

        # Infinite loop to visualize the results when new checkpoint is created.
        while True:
            last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
                FLAGS.checkpoint_dir, last_checkpoint)
            start = time.time()
            tf.logging.info('Starting visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
            tf.logging.info('Visualizing with model %s', last_checkpoint)

            all_ious = []
            with sv.managed_session(FLAGS.master,
                                    start_standard_services=False) as sess:
                sv.start_queue_runners(sess)
                sv.saver.restore(sess, last_checkpoint)

                for batch in range(num_batches):
                    ops = [predicted_labels, gt_labels, seq_name]
                    if FLAGS.save_embeddings:
                        ops.append(all_embeddings)
                    tf.logging.info('Visualizing batch %d / %d', batch + 1,
                                    num_batches)
                    res = sess.run(ops)
                    tf.logging.info('Forwarding done')
                    pred_labels_val, gt_labels_val, seq_name_val = res[:3]
                    if FLAGS.save_embeddings:
                        all_embeddings_val = res[3]
                    else:
                        all_embeddings_val = None
                    seq_ious = _process_seq_data(segmentation_dir,
                                                 embeddings_dir, seq_name_val,
                                                 pred_labels_val,
                                                 gt_labels_val,
                                                 all_embeddings_val)
                    all_ious.append(seq_ious)
            all_ious = np.concatenate(all_ious, axis=0)
            tf.logging.info('n_seqs %s, mIoU %f', all_ious.shape,
                            all_ious.mean())
            tf.logging.info('Finished visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
            result_dir = FLAGS.vis_logdir + '/results/'
            tf.gfile.MakeDirs(result_dir)
            with tf.gfile.GFile(result_dir + seq_name_val + '.txt', 'w') as f:
                f.write(str(all_ious))
            if FLAGS.first_frame_finetuning or FLAGS.eval_once_and_quit:
                break
            time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
            if time_to_next_eval > 0:
                time.sleep(time_to_next_eval)
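

# Hypothetical entry-point sketch: scripts in this codebase are typically
# launched via tf.app.run(), which parses the flags and then calls main, e.g.
#
#   if __name__ == '__main__':
#       tf.app.run()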