Example No. 1
def _has_foreground_and_background_in_first_frame_2(label,
                                                    decoder_output_stride):
  """Checks if the labels have foreground and background in the first frame.

  Second attempt, this time we use the actual output dimension for resizing.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    decoder_output_stride: Integer, the stride of the decoder output.

  Returns:
    Boolean, whether the labels have foreground and background in the first
      frame.
  """
  h, w = train_utils.resolve_shape(label)[1:3]
  h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
  w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
  label_downscaled = tf.squeeze(
      tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
                                       align_corners=True), axis=0)
  is_bg = tf.equal(label_downscaled, 0)
  is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
  return tf.logical_and(has_bg, has_fg)
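
For intuition, the check in Example No. 1 boils down to counting background (label 0) and non-background pixels in the downscaled first frame. Below is a minimal NumPy sketch of that logic, with a hypothetical MIN_LABEL_COUNT (the real constant is defined elsewhere in the module):

import numpy as np

MIN_LABEL_COUNT = 20  # hypothetical threshold for this sketch


def has_foreground_and_background(first_frame_label):
  """NumPy version of the check: label 0 is background, everything else is foreground."""
  is_bg = first_frame_label == 0
  is_fg = ~is_bg
  has_fg = is_fg.sum() >= MIN_LABEL_COUNT
  has_bg = is_bg.sum() >= MIN_LABEL_COUNT
  return has_fg and has_bg


# A 10x10 frame with a 5x5 foreground object: 25 fg pixels, 75 bg pixels.
label = np.zeros((10, 10, 1), dtype=np.int32)
label[2:7, 2:7] = 1
print(has_foreground_and_background(label))  # True, since both counts exceed 20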
Example No. 2
def _has_enough_pixels_of_each_object_in_first_frame(
    label, decoder_output_stride):
  """Checks if for each object (incl. background) enough pixels are visible.

  During test time, we will usually not see a reference frame in which only
  very few pixels of one object are visible. These cases can be problematic
  during training, especially if more than the 1-nearest neighbor is used.
  That's why this function can be used to detect and filter these cases.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    decoder_output_stride: Integer, the stride of the decoder output.

  Returns:
    Boolean, whether the labels have enough pixels of each object in the first
      frame.
  """
  h, w = train_utils.resolve_shape(label)[1:3]
  h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
  w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
  label_downscaled = tf.squeeze(
      tf.image.resize_nearest_neighbor(label[0, tf.newaxis], [h_sub, w_sub],
                                       align_corners=True), axis=0)
  _, _, counts = tf.unique_with_counts(
      tf.reshape(label_downscaled, [-1]))
  has_enough_pixels_per_object = tf.reduce_all(
      tf.greater_equal(counts, MIN_LABEL_COUNT))
  return has_enough_pixels_per_object
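
Example No. 2 generalizes the same idea from foreground/background to every object id via tf.unique_with_counts. An illustrative NumPy sketch, again with a hypothetical MIN_LABEL_COUNT:

import numpy as np

MIN_LABEL_COUNT = 20  # hypothetical threshold for this sketch


def has_enough_pixels_per_object(first_frame_label):
  """Every label id that appears (including background 0) must cover >= MIN_LABEL_COUNT pixels."""
  _, counts = np.unique(first_frame_label, return_counts=True)
  return bool(np.all(counts >= MIN_LABEL_COUNT))


label = np.zeros((10, 10), dtype=np.int32)
label[2:7, 2:7] = 1  # object 1 covers 25 pixels
label[0, 0] = 2      # object 2 covers a single pixel, which is too few
print(has_enough_pixels_per_object(label))  # False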
Example No. 3
def _has_enough_pixels_of_each_object_in_first_frame(label,
                                                     decoder_output_stride):
    """Checks if for each object (incl. background) enough pixels are visible.

  During test time, we will usually not see a reference frame in which only
  very few pixels of one object are visible. These cases can be problematic
  during training, especially if more than the 1-nearest neighbor is used.
  That's why this function can be used to detect and filter these cases.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    decoder_output_stride: Integer, the stride of the decoder output.

  Returns:
    Boolean, whether the labels have enough pixels of each object in the first
      frame.
  """
    h, w = train_utils.resolve_shape(label)[1:3]
    h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
    w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
    label_downscaled = tf.squeeze(tf.image.resize_nearest_neighbor(
        label[0, tf.newaxis], [h_sub, w_sub], align_corners=True),
                                  axis=0)
    _, _, counts = tf.unique_with_counts(tf.reshape(label_downscaled, [-1]))
    has_enough_pixels_per_object = tf.reduce_all(
        tf.greater_equal(counts, MIN_LABEL_COUNT))
    return has_enough_pixels_per_object
Example No. 4
def _has_foreground_and_background_in_first_frame(label, subsampling_factor):
  """Checks if the labels have foreground and background in the first frame.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    subsampling_factor: Integer, the subsampling factor.

  Returns:
    Boolean, whether the labels have foreground and background in the first
      frame.
  """
  h, w = train_utils.resolve_shape(label)[1:3]
  label_downscaled = tf.squeeze(
      tf.image.resize_nearest_neighbor(label[0, tf.newaxis],
                                       [h // subsampling_factor,
                                        w // subsampling_factor],
                                       align_corners=True),
      axis=0)
  is_bg = tf.equal(label_downscaled, 0)
  is_fg = tf.logical_not(is_bg)
  # Just using reduce_any was not robust enough, so let's make sure the count
  # is above MIN_LABEL_COUNT.
  fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
  bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
  has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
  has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
  return tf.logical_and(has_bg, has_fg)
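
The difference from the _2 variant above is how the downscaled size is computed: plain integer division here versus model.scale_dimension there. Assuming scale_dimension follows the usual DeepLab convention of int((dim - 1) * scale + 1) (an assumption, not shown in this listing), the two can disagree by one pixel per side:

def scale_dimension_sketch(dim, scale):
  """Assumed behavior of model.scale_dimension (align-corners style scaling)."""
  return int((float(dim) - 1.0) * scale + 1.0)


# With a decoder output stride of 4, a 481-pixel side maps to 121 features,
# while plain integer division gives 120.
print(scale_dimension_sketch(481, 1.0 / 4))  # 121
print(481 // 4)                              # 120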
Example No. 5
def _has_foreground_and_background_in_first_frame_2(label,
                                                    decoder_output_stride):
    """Checks if the labels have foreground and background in the first frame.

  Second attempt, this time we use the actual output dimension for resizing.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    decoder_output_stride: Integer, the stride of the decoder output.

  Returns:
    Boolean, whether the labels have foreground and background in the first
      frame.
  """
    h, w = train_utils.resolve_shape(label)[1:3]
    h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
    w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
    label_downscaled = tf.squeeze(tf.image.resize_nearest_neighbor(
        label[0, tf.newaxis], [h_sub, w_sub], align_corners=True),
                                  axis=0)
    is_bg = tf.equal(label_downscaled, 0)
    is_fg = tf.logical_not(is_bg)
    # Just using reduce_any was not robust enough, so let's make sure the count
    # is above MIN_LABEL_COUNT.
    fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
    bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
    has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
    has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
    return tf.logical_and(has_bg, has_fg)
Example No. 6
def _has_foreground_and_background_in_first_frame(label, subsampling_factor):
    """Checks if the labels have foreground and background in the first frame.

  Args:
    label: Label tensor of shape [num_frames, height, width, 1].
    subsampling_factor: Integer, the subsampling factor.

  Returns:
    Boolean, whether the labels have foreground and background in the first
      frame.
  """
    h, w = train_utils.resolve_shape(label)[1:3]
    label_downscaled = tf.squeeze(
        tf.image.resize_nearest_neighbor(
            label[0, tf.newaxis],
            [h // subsampling_factor, w // subsampling_factor],
            align_corners=True),
        axis=0)
    is_bg = tf.equal(label_downscaled, 0)
    is_fg = tf.logical_not(is_bg)
    # Just using reduce_any was not robust enough, so let's make sure the count
    # is above MIN_LABEL_COUNT.
    fg_count = tf.reduce_sum(tf.cast(is_fg, tf.int32))
    bg_count = tf.reduce_sum(tf.cast(is_bg, tf.int32))
    has_fg = tf.greater_equal(fg_count, MIN_LABEL_COUNT)
    has_bg = tf.greater_equal(bg_count, MIN_LABEL_COUNT)
    return tf.logical_and(has_bg, has_fg)
Example No. 7
def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
                   ignore_label):
    """Builds a clone of DeepLab.

  Args:
    inputs_queue_or_samples: A prefetch queue for images and labels, or
      directly a dict of the samples.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
      dictionary of multi-scale logits names to logits. For each output_type,
      the dictionary has keys which correspond to the scales and values which
      correspond to the logits. For example, if `scales` equals [1.0, 1.5],
      then the keys would include 'merged_logits', 'logits_1.00' and
      'logits_1.50'.

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      or triplet.
  """
    if hasattr(inputs_queue_or_samples, 'dequeue'):
        samples = inputs_queue_or_samples.dequeue()
    else:
        samples = inputs_queue_or_samples
    train_crop_size = (None if 0 in FLAGS.train_crop_size else
                       FLAGS.train_crop_size)

    model_options = common.VideoModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=train_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if model_options.classification_loss == 'softmax_with_attention':
        clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

        # Create summaries of ground truth labels.
        for n in range(clone_batch_size):
            tf.summary.image(
                'gt_label_%d' % n,
                tf.cast(
                    samples[common.LABEL][
                        n * FLAGS.train_num_frames_per_video:
                        (n + 1) * FLAGS.train_num_frames_per_video],
                    tf.uint8) * 32,
                max_outputs=FLAGS.train_num_frames_per_video)

        if common.PRECEDING_FRAME_LABEL in samples:
            preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
            init_softmax = []
            for n in range(clone_batch_size):
                init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
                    preceding_frame_label[n, tf.newaxis],
                    samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                          tf.newaxis],
                    common.parse_decoder_output_stride(),
                    reduce_labels=True)
                init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
                init_softmax.append(init_softmax_n)
                tf.summary.image(
                    'preceding_frame_label',
                    tf.cast(preceding_frame_label[n, tf.newaxis], tf.uint8) *
                    32)
        else:
            init_softmax = None

        outputs_to_scales_to_logits = (
            model.multi_scale_logits_with_nearest_neighbor_matching(
                samples[common.IMAGE],
                model_options=model_options,
                image_pyramid=FLAGS.image_pyramid,
                weight_decay=FLAGS.weight_decay,
                is_training=True,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                reference_labels=samples[common.LABEL],
                clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
                num_frames_per_video=FLAGS.train_num_frames_per_video,
                embedding_dimension=FLAGS.embedding_dimension,
                max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
                k_nearest_neighbors=FLAGS.k_nearest_neighbors,
                use_softmax_feedback=FLAGS.use_softmax_feedback,
                initial_softmax_feedback=init_softmax,
                embedding_seg_feature_dimension=(
                    FLAGS.embedding_seg_feature_dimension),
                embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
                embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
                embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
                normalize_nearest_neighbor_distances=(
                    FLAGS.normalize_nearest_neighbor_distances),
                also_attend_to_previous_frame=(
                    FLAGS.also_attend_to_previous_frame),
                damage_initial_previous_frame_mask=(
                    FLAGS.damage_initial_previous_frame_mask),
                use_local_previous_frame_attention=(
                    FLAGS.use_local_previous_frame_attention),
                previous_frame_attention_window_size=(
                    FLAGS.previous_frame_attention_window_size),
                use_first_frame_matching=FLAGS.use_first_frame_matching))
    else:
        outputs_to_scales_to_logits = model.multi_scale_logits_v2(
            samples[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

    if model_options.classification_loss == 'softmax':
        for output, num_classes in six.iteritems(outputs_to_num_classes):
            train_utils.add_softmax_cross_entropy_loss_for_each_scale(
                outputs_to_scales_to_logits[output],
                samples[common.LABEL],
                num_classes,
                ignore_label,
                loss_weight=1.0,
                upsample_logits=FLAGS.upsample_logits,
                scope=output)
    elif model_options.classification_loss == 'triplet':
        for output, _ in six.iteritems(outputs_to_num_classes):
            train_utils.add_triplet_loss_for_each_scale(
                FLAGS.train_batch_size // FLAGS.num_clones,
                FLAGS.train_num_frames_per_video,
                FLAGS.embedding_dimension,
                outputs_to_scales_to_logits[output],
                samples[common.LABEL],
                scope=output)
    elif model_options.classification_loss == 'softmax_with_attention':
        labels = samples[common.LABEL]
        batch_size = FLAGS.train_batch_size // FLAGS.num_clones
        num_frames_per_video = FLAGS.train_num_frames_per_video
        h, w = train_utils.resolve_shape(labels)[1:3]
        labels = tf.reshape(
            labels, tf.stack([batch_size, num_frames_per_video, h, w, 1]))
        # Strip the reference labels off.
        if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
            n_ref_frames = 2
        else:
            n_ref_frames = 1
        labels = labels[:, n_ref_frames:]
        # Merge batch and time dimensions.
        labels = tf.reshape(
            labels,
            tf.stack(
                [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))

        for output, num_classes in six.iteritems(outputs_to_num_classes):
            train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
                outputs_to_scales_to_logits[output],
                labels,
                ignore_label,
                loss_weight=1.0,
                upsample_logits=FLAGS.upsample_logits,
                scope=output,
                top_k_percent_pixels=FLAGS.top_k_percent_pixels,
                hard_example_mining_step=FLAGS.hard_example_mining_step)
    else:
        raise ValueError('Only support softmax, softmax_with_attention'
                         ' or triplet for classification_loss.')

    return outputs_to_scales_to_logits
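
In the softmax_with_attention branch above, the label handling is the subtle part: labels arrive flattened over batch and time, are reshaped to separate the two, the reference frame(s) are stripped off, and the remaining frames are flattened again. A NumPy sketch with hypothetical sizes:

import numpy as np

batch_size, num_frames_per_video, h, w = 2, 4, 8, 8  # hypothetical sizes
n_ref_frames = 2  # 2 with softmax feedback or previous-frame attention, else 1

labels = np.zeros((batch_size * num_frames_per_video, h, w, 1), dtype=np.int32)
labels = labels.reshape(batch_size, num_frames_per_video, h, w, 1)
labels = labels[:, n_ref_frames:]  # strip the reference labels off
labels = labels.reshape(
    batch_size * (num_frames_per_video - n_ref_frames), h, w, 1)
print(labels.shape)  # (4, 8, 8, 1)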
Example No. 8
def predict_labels(images,
                   model_options,
                   image_pyramid=None,
                   reference_labels=None,
                   k_nearest_neighbors=1,
                   embedding_dimension=None,
                   use_softmax_feedback=False,
                   initial_softmax_feedback=None,
                   embedding_seg_feature_dimension=256,
                   embedding_seg_n_layers=4,
                   embedding_seg_kernel_size=7,
                   embedding_seg_atrous_rates=None,
                   also_return_softmax_probabilities=False,
                   num_frames_per_video=None,
                   normalize_nearest_neighbor_distances=False,
                   also_attend_to_previous_frame=False,
                   use_local_previous_frame_attention=False,
                   previous_frame_attention_window_size=9,
                   use_first_frame_matching=True,
                   also_return_embeddings=False,
                   ref_embeddings=None):
    """Predicts segmentation labels.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: An InternalModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.
    reference_labels: A tensor of size [batch, height, width, 1] of ground
      truth labels used to perform a nearest neighbor query.
    k_nearest_neighbors: Integer, the number of neighbors to use for nearest
      neighbor queries.
    embedding_dimension: Integer, the dimension used for the learned embedding.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: Float32 tensor, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Typically only useful for inference. Only has an effect if
      use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    also_return_softmax_probabilities: Boolean, if true, additionally return
      the softmax probabilities as second return value.
    num_frames_per_video: Integer, the number of frames per video.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect, if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.

  Returns:
    A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
    If also_return_softmax_probabilities is True, the second return value are
      the softmax probabilities.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      nor triplet.
  """
    if (model_options.classification_loss == 'triplet'
            and reference_labels is None):
        raise ValueError('Need reference_labels for triplet loss')

    if model_options.classification_loss == 'softmax_with_attention':
        if embedding_dimension is None:
            raise ValueError(
                'Need embedding_dimension for softmax_with_attention '
                'loss')
        if reference_labels is None:
            raise ValueError(
                'Need reference_labels for softmax_with_attention loss')
        res = (multi_scale_logits_with_nearest_neighbor_matching(
            images,
            model_options=model_options,
            image_pyramid=image_pyramid,
            is_training=False,
            reference_labels=reference_labels,
            clone_batch_size=1,
            num_frames_per_video=num_frames_per_video,
            embedding_dimension=embedding_dimension,
            max_neighbors_per_object=0,
            k_nearest_neighbors=k_nearest_neighbors,
            use_softmax_feedback=use_softmax_feedback,
            initial_softmax_feedback=initial_softmax_feedback,
            embedding_seg_feature_dimension=embedding_seg_feature_dimension,
            embedding_seg_n_layers=embedding_seg_n_layers,
            embedding_seg_kernel_size=embedding_seg_kernel_size,
            embedding_seg_atrous_rates=embedding_seg_atrous_rates,
            normalize_nearest_neighbor_distances=(
                normalize_nearest_neighbor_distances),
            also_attend_to_previous_frame=also_attend_to_previous_frame,
            use_local_previous_frame_attention=(
                use_local_previous_frame_attention),
            previous_frame_attention_window_size=(
                previous_frame_attention_window_size),
            use_first_frame_matching=use_first_frame_matching,
            also_return_embeddings=also_return_embeddings,
            ref_embeddings=ref_embeddings))
        if also_return_embeddings:
            outputs_to_scales_to_logits, embeddings = res
        else:
            outputs_to_scales_to_logits = res
            embeddings = None
    else:
        outputs_to_scales_to_logits = multi_scale_logits_v2(
            images,
            model_options=model_options,
            image_pyramid=image_pyramid,
            is_training=False,
            fine_tune_batch_norm=False)

    predictions = {}
    for output in sorted(outputs_to_scales_to_logits):
        scales_to_logits = outputs_to_scales_to_logits[output]
        original_logits = scales_to_logits[MERGED_LOGITS_SCOPE]
        if isinstance(original_logits, list):
            assert len(original_logits) == 1
            original_logits = original_logits[0]
        logits = tf.image.resize_bilinear(original_logits,
                                          tf.shape(images)[1:3],
                                          align_corners=True)
        if model_options.classification_loss in ('softmax',
                                                 'softmax_with_attention'):
            predictions[output] = tf.argmax(logits, 3)
        elif model_options.classification_loss == 'triplet':
            # To keep this fast, we do the nearest neighbor assignment on the
            # resolution at which the embedding is extracted and scale the
            # result up afterwards.
            embeddings = original_logits
            reference_labels_logits_size = tf.squeeze(
                tf.image.resize_nearest_neighbor(
                    reference_labels[tf.newaxis],
                    train_utils.resolve_shape(embeddings)[1:3],
                    align_corners=True),
                axis=0)
            nn_labels = embedding_utils.assign_labels_by_nearest_neighbors(
                embeddings[0], embeddings[1:], reference_labels_logits_size,
                k_nearest_neighbors)
            predictions[common.OUTPUT_TYPE] = tf.image.resize_nearest_neighbor(
                nn_labels, tf.shape(images)[1:3], align_corners=True)
        else:
            raise ValueError(
                'Only support softmax, triplet, or softmax_with_attention for '
                'classification_loss.')

    if also_return_embeddings:
        assert also_return_softmax_probabilities
        return predictions, tf.nn.softmax(original_logits, axis=-1), embeddings
    elif also_return_softmax_probabilities:
        return predictions, tf.nn.softmax(original_logits, axis=-1)
    else:
        return predictions
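
The triplet branch transfers labels by nearest-neighbor matching in embedding space at the embedding resolution and only then upscales the result. This is not embedding_utils.assign_labels_by_nearest_neighbors itself, just an illustrative 1-nearest-neighbor NumPy sketch of the idea:

import numpy as np


def assign_labels_by_1nn(ref_embeddings, query_embeddings, ref_labels):
  """Copies the label of the closest reference pixel to each query pixel (1-NN)."""
  h, w, d = query_embeddings.shape
  ref_flat = ref_embeddings.reshape(-1, d)      # [num_ref_pixels, d]
  query_flat = query_embeddings.reshape(-1, d)  # [num_query_pixels, d]
  # Squared Euclidean distance between every query pixel and every reference pixel.
  dists = ((query_flat[:, None, :] - ref_flat[None, :, :]) ** 2).sum(axis=-1)
  nearest = dists.argmin(axis=1)
  return ref_labels.reshape(-1)[nearest].reshape(h, w)


ref_emb = np.random.randn(8, 8, 16)
query_emb = ref_emb + 0.01 * np.random.randn(8, 8, 16)  # nearly identical frame
ref_labels = np.random.randint(0, 3, size=(8, 8))
print((assign_labels_by_1nn(ref_emb, query_emb, ref_labels) == ref_labels).all())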
Example No. 9
def get(dataset,
        num_frames_per_video,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        preprocess_image_and_label=True,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None,
        batch_capacity_factor=32,
        video_frames_are_decoded=False,
        decoder_output_stride=None,
        first_frame_finetuning=False,
        sample_only_first_frame_for_finetuning=False,
        sample_adjacent_and_consistent_query_frames=False,
        remap_labels_to_reference_frame=True,
        generate_prev_frame_mask_by_mask_damaging=False,
        three_frame_dataset=False,
        add_prev_frame_label=True):
  """Gets the dataset split for semantic segmentation.

  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider, which returns the
  raw dataset split, (2) input_preprocess, which preprocesses the raw data,
  and (3) the TensorFlow operation that batches the preprocessed data. The
  output can then be directly used for training, evaluation, or visualization.

  Args:
    dataset: An instance of slim Dataset.
    num_frames_per_video: The number of frames used per video.
    crop_size: Image crop size [height, width].
    batch_size: Batch size.
    min_resize_value: Desired size of the smaller image side.
    max_resize_value: Maximum allowed size of the larger image side.
    resize_factor: Resized dimensions are multiple of factor plus one.
    min_scale_factor: Minimum scale factor value.
    max_scale_factor: Maximum scale factor value.
    scale_factor_step_size: The step size from min scale factor to max scale
      factor. The input is randomly scaled based on the value of
      (min_scale_factor, max_scale_factor, scale_factor_step_size).
    preprocess_image_and_label: Boolean variable specifies if preprocessing of
      image and label will be performed or not.
    num_readers: Number of readers for data provider.
    num_threads: Number of threads for batching data.
    dataset_split: Dataset split.
    is_training: Is training or not.
    model_variant: Model variant (string) for choosing how to mean-subtract the
      images. See feature_extractor.network_map for supported model variants.
    batch_capacity_factor: Batch capacity factor affecting the training queue
      batch capacity.
    video_frames_are_decoded: Boolean, whether the video frames are already
      decoded.
    decoder_output_stride: Integer, the stride of the decoder output.
    first_frame_finetuning: Boolean, whether to only sample the first frame
      for fine-tuning.
    sample_only_first_frame_for_finetuning: Boolean, whether to only sample the
      first frame during fine-tuning. This should be False when using lucid or
      wonderland data, but true when fine-tuning on the first frame only.
      Only has an effect if first_frame_finetuning is True.
    sample_adjacent_and_consistent_query_frames: Boolean, if true, the query
      frames (all but the first frame which is the reference frame) will be
      sampled such that they are adjacent video frames and have the same
      crop coordinates and flip augmentation.
    remap_labels_to_reference_frame: Boolean, whether to remap the labels of
      the query frames to match the labels of the (downscaled) reference frame.
      If a query frame contains a label which is not present in the reference,
      it will be mapped to background.
    generate_prev_frame_mask_by_mask_damaging: Boolean, whether to generate
      the masks used as guidance from the previous frame by damaging the
      ground truth mask.
    three_frame_dataset: Boolean, whether the dataset has exactly three frames
      per video of which the first is to be used as reference and the two
      others are consecutive frames to be used as query frames.
    add_prev_frame_label: Boolean, whether to sample one more frame before the
      first query frame to obtain a previous frame label. Only has an effect,
      if sample_adjacent_and_consistent_query_frames is True and
      generate_prev_frame_mask_by_mask_damaging is False.

  Returns:
    A dictionary of batched Tensors for semantic segmentation.

  Raises:
    ValueError: dataset_split is None, or Failed to find labels.
  """
  if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

  data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)
  image, label, object_label, image_name, height, width, video_id = _get_data(
      data_provider, dataset_split, video_frames_are_decoded)

  sampling_is_valid = tf.constant(True)
  if num_frames_per_video is not None:
    total_num_frames = tf.shape(image)[0]
    if first_frame_finetuning or three_frame_dataset:
      if sample_only_first_frame_for_finetuning:
        assert not sample_adjacent_and_consistent_query_frames, (
            'this option does not make sense for sampling only first frame.')
        # Sample the first frame num_frames_per_video times.
        sel_indices = tf.tile(tf.constant(0, dtype=tf.int32)[tf.newaxis],
                              multiples=[num_frames_per_video])
      else:
        if sample_adjacent_and_consistent_query_frames:
          if add_prev_frame_label:
            num_frames_per_video += 1
          # Since this is first frame fine-tuning, we'll for now assume that
          # each sequence has exactly 3 images: the ref frame and 2 adjacent
          # query frames.
          assert num_frames_per_video == 3
          with tf.control_dependencies([tf.assert_equal(total_num_frames, 3)]):
            sel_indices = tf.constant([1, 2], dtype=tf.int32)
        else:
          # Sample num_frames_per_video - 1 query frames which are not the
          # first frame.
          sel_indices = tf.random_shuffle(
              tf.range(1, total_num_frames))[:(num_frames_per_video - 1)]
        # Concat first frame as reference frame to the front.
        sel_indices = tf.concat([tf.constant(0, dtype=tf.int32)[tf.newaxis],
                                 sel_indices], axis=0)
    else:
      if sample_adjacent_and_consistent_query_frames:
        if add_prev_frame_label:
          # Sample one more frame which we can use to provide initial softmax
          # feedback.
          num_frames_per_video += 1
        ref_idx = tf.random_shuffle(tf.range(total_num_frames))[0]
        sampling_is_valid = tf.greater_equal(total_num_frames,
                                             num_frames_per_video)
        def sample_query_start_idx():
          return tf.random_shuffle(
              tf.range(total_num_frames - num_frames_per_video + 1))[0]
        query_start_idx = tf.cond(sampling_is_valid, sample_query_start_idx,
                                  lambda: tf.constant(0, dtype=tf.int32))
        def sample_sel_indices():
          return tf.concat(
              [ref_idx[tf.newaxis],
               tf.range(
                   query_start_idx,
                   query_start_idx + (num_frames_per_video - 1))], axis=0)
        sel_indices = tf.cond(
            sampling_is_valid, sample_sel_indices,
            lambda: tf.zeros((num_frames_per_video,), dtype=tf.int32))
      else:
        # Randomly sample some frames from the video.
        sel_indices = tf.random_shuffle(
            tf.range(total_num_frames))[:num_frames_per_video]
    image = tf.gather(image, sel_indices, axis=0)
  if not video_frames_are_decoded:
    image = decode_image_sequence(image)

  if label is not None:
    if num_frames_per_video is not None:
      label = tf.gather(label, sel_indices, axis=0)
    if not video_frames_are_decoded:
      label = decode_image_sequence(label, image_format='png', channels=1)

    # Sometimes, label is saved as [num_frames_per_video, height, width] or
    # [num_frames_per_video, height, width, 1]. We change it to be
    # [num_frames_per_video, height, width, 1].
    if label.shape.ndims == 3:
      label = tf.expand_dims(label, 3)
    elif label.shape.ndims == 4 and label.shape.dims[3] == 1:
      pass
    else:
      raise ValueError('Input label shape must be '
                       '[num_frames_per_video, height, width],'
                       ' or [num_frames, height, width, 1]. '
                       'Got {}'.format(label.shape.ndims))
    label.set_shape([None, None, None, 1])

  # Add size of first dimension since tf can't figure it out automatically.
  image.set_shape((num_frames_per_video, None, None, None))
  if label is not None:
    label.set_shape((num_frames_per_video, None, None, None))

  preceding_frame_label = None
  if preprocess_image_and_label:
    if num_frames_per_video is None:
      raise ValueError('num_frames_per_video must be specified for preproc.')
    original_images = []
    images = []
    labels = []
    if sample_adjacent_and_consistent_query_frames:
      num_frames_individual_preproc = 1
    else:
      num_frames_individual_preproc = num_frames_per_video
    for frame_idx in range(num_frames_individual_preproc):
      original_image_t, image_t, label_t = (
          input_preprocess.preprocess_image_and_label(
              image[frame_idx],
              label[frame_idx],
              crop_height=crop_size[0] if crop_size is not None else None,
              crop_width=crop_size[1] if crop_size is not None else None,
              min_resize_value=min_resize_value,
              max_resize_value=max_resize_value,
              resize_factor=resize_factor,
              min_scale_factor=min_scale_factor,
              max_scale_factor=max_scale_factor,
              scale_factor_step_size=scale_factor_step_size,
              ignore_label=dataset.ignore_label,
              is_training=is_training,
              model_variant=model_variant))
      original_images.append(original_image_t)
      images.append(image_t)
      labels.append(label_t)
    if sample_adjacent_and_consistent_query_frames:
      imgs_for_preproc = [image[frame_idx] for frame_idx in
                          range(1, num_frames_per_video)]
      labels_for_preproc = [label[frame_idx] for frame_idx in
                            range(1, num_frames_per_video)]
      original_image_rest, image_rest, label_rest = (
          input_preprocess.preprocess_images_and_labels_consistently(
              imgs_for_preproc,
              labels_for_preproc,
              crop_height=crop_size[0] if crop_size is not None else None,
              crop_width=crop_size[1] if crop_size is not None else None,
              min_resize_value=min_resize_value,
              max_resize_value=max_resize_value,
              resize_factor=resize_factor,
              min_scale_factor=min_scale_factor,
              max_scale_factor=max_scale_factor,
              scale_factor_step_size=scale_factor_step_size,
              ignore_label=dataset.ignore_label,
              is_training=is_training,
              model_variant=model_variant))
      original_images.extend(original_image_rest)
      images.extend(image_rest)
      labels.extend(label_rest)
    assert len(original_images) == num_frames_per_video
    assert len(images) == num_frames_per_video
    assert len(labels) == num_frames_per_video

    if remap_labels_to_reference_frame:
      # Remap labels to indices into the labels of the (downscaled) reference
      # frame, or 0, i.e. background, for labels which are not present
      # in the reference.
      reference_labels = labels[0][tf.newaxis]
      h, w = train_utils.resolve_shape(reference_labels)[1:3]
      embedding_height = model.scale_dimension(
          h, 1.0 / decoder_output_stride)
      embedding_width = model.scale_dimension(
          w, 1.0 / decoder_output_stride)
      reference_labels_embedding_size = tf.squeeze(
          tf.image.resize_nearest_neighbor(
              reference_labels, tf.stack([embedding_height, embedding_width]),
              align_corners=True),
          axis=0)
      # Get sorted unique labels in the reference frame.
      labels_in_ref_frame, _ = tf.unique(
          tf.reshape(reference_labels_embedding_size, [-1]))
      labels_in_ref_frame = tf.contrib.framework.sort(labels_in_ref_frame)
      for idx in range(1, len(labels)):
        ref_label_mask = tf.equal(
            labels[idx],
            labels_in_ref_frame[tf.newaxis, tf.newaxis, :])
        remapped = tf.argmax(tf.cast(ref_label_mask, tf.uint8), axis=-1,
                             output_type=tf.int32)
        # Set to 0 if label is not present
        is_in_ref = tf.reduce_any(ref_label_mask, axis=-1)
        remapped *= tf.cast(is_in_ref, tf.int32)
        labels[idx] = remapped[..., tf.newaxis]

    if sample_adjacent_and_consistent_query_frames:
      if first_frame_finetuning and generate_prev_frame_mask_by_mask_damaging:
        preceding_frame_label = mask_damaging.damage_masks(labels[1])
      elif add_prev_frame_label:
        # Discard the image of the additional frame and take the label as
        # initialization for softmax feedback.
        original_images = [original_images[0]] + original_images[2:]
        preceding_frame_label = labels[1]
        images = [images[0]] + images[2:]
        labels = [labels[0]] + labels[2:]
        num_frames_per_video -= 1

    original_image = tf.stack(original_images, axis=0)
    image = tf.stack(images, axis=0)
    label = tf.stack(labels, axis=0)
  else:
    if label is not None:
      # Need to set label shape due to batching.
      label.set_shape([num_frames_per_video,
                       None if crop_size is None else crop_size[0],
                       None if crop_size is None else crop_size[1],
                       1])
    original_image = tf.to_float(tf.zeros_like(label))
    if crop_size is None:
      height = tf.shape(image)[1]
      width = tf.shape(image)[2]
    else:
      height = crop_size[0]
      width = crop_size[1]

  sample = {'image': image,
            'image_name': image_name,
            'height': height,
            'width': width,
            'video_id': video_id}
  if label is not None:
    sample['label'] = label

  if object_label is not None:
    sample['object_label'] = object_label

  if preceding_frame_label is not None:
    sample['preceding_frame_label'] = preceding_frame_label

  if not is_training:
    # Original image is only used during visualization.
    sample['original_image'] = original_image

  if is_training:
    if first_frame_finetuning:
      keep_input = tf.constant(True)
    else:
      keep_input = tf.logical_and(sampling_is_valid, tf.logical_and(
          _has_enough_pixels_of_each_object_in_first_frame(
              label, decoder_output_stride),
          _has_foreground_and_background_in_first_frame_2(
              label, decoder_output_stride)))

    batched = tf.train.maybe_batch(sample,
                                   keep_input=keep_input,
                                   batch_size=batch_size,
                                   num_threads=num_threads,
                                   capacity=batch_capacity_factor * batch_size,
                                   dynamic_pad=True)
  else:
    batched = tf.train.batch(sample,
                             batch_size=batch_size,
                             num_threads=num_threads,
                             capacity=batch_capacity_factor * batch_size,
                             dynamic_pad=True)

  # Flatten from [batch, num_frames_per_video, ...] to
  # [batch * num_frames_per_video, ...].
  cropped_height = train_utils.resolve_shape(batched['image'])[2]
  cropped_width = train_utils.resolve_shape(batched['image'])[3]
  if num_frames_per_video is None:
    first_dim = -1
  else:
    first_dim = batch_size * num_frames_per_video
  batched['image'] = tf.reshape(batched['image'],
                                [first_dim, cropped_height, cropped_width, 3])
  if label is not None:
    batched['label'] = tf.reshape(batched['label'],
                                  [first_dim, cropped_height, cropped_width, 1])
  return batched
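
The remap_labels_to_reference_frame block above maps every query-frame label to its index among the sorted unique labels of the (downscaled) reference frame, with labels absent from the reference sent to background. An illustrative NumPy sketch of that remapping (not the TF graph code itself):

import numpy as np

labels_in_ref_frame = np.array([0, 3, 7])  # sorted unique labels of the reference frame
query_label = np.array([[0, 3], [7, 5]])   # label 5 does not occur in the reference

# One-hot match of each query pixel against the reference labels.
ref_label_mask = (query_label[..., np.newaxis] ==
                  labels_in_ref_frame[np.newaxis, np.newaxis, :])
remapped = ref_label_mask.argmax(axis=-1)  # index into labels_in_ref_frame
remapped *= ref_label_mask.any(axis=-1)    # unmatched labels map to 0 (background)
print(remapped)  # [[0 1]
                 #  [2 0]]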
Example No. 10
def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
                   ignore_label):
  """Builds a clone of DeepLab.

  Args:
    inputs_queue_or_samples: A prefetch queue for images and labels, or
      directly a dict of the samples.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
      dictionary of multi-scale logits names to logits. For each output_type,
      the dictionary has keys which correspond to the scales and values which
      correspond to the logits. For example, if `scales` equals [1.0, 1.5],
      then the keys would include 'merged_logits', 'logits_1.00' and
      'logits_1.50'.

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      or triplet.
  """
  if hasattr(inputs_queue_or_samples, 'dequeue'):
    samples = inputs_queue_or_samples.dequeue()
  else:
    samples = inputs_queue_or_samples
  train_crop_size = (None if 0 in FLAGS.train_crop_size else
                     FLAGS.train_crop_size)

  model_options = common.VideoModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=train_crop_size,
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  if model_options.classification_loss == 'softmax_with_attention':
    clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

    # Create summaries of ground truth labels.
    for n in range(clone_batch_size):
      tf.summary.image(
          'gt_label_%d' % n,
          tf.cast(samples[common.LABEL][
              n * FLAGS.train_num_frames_per_video:
              (n + 1) * FLAGS.train_num_frames_per_video],
                  tf.uint8) * 32, max_outputs=FLAGS.train_num_frames_per_video)

    if common.PRECEDING_FRAME_LABEL in samples:
      preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
      init_softmax = []
      for n in range(clone_batch_size):
        init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
            preceding_frame_label[n, tf.newaxis],
            samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                  tf.newaxis],
            common.parse_decoder_output_stride(),
            reduce_labels=True)
        init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
        init_softmax.append(init_softmax_n)
        tf.summary.image('preceding_frame_label',
                         tf.cast(preceding_frame_label[n, tf.newaxis],
                                 tf.uint8) * 32)
    else:
      init_softmax = None

    outputs_to_scales_to_logits = (
        model.multi_scale_logits_with_nearest_neighbor_matching(
            samples[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            reference_labels=samples[common.LABEL],
            clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
            num_frames_per_video=FLAGS.train_num_frames_per_video,
            embedding_dimension=FLAGS.embedding_dimension,
            max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
            k_nearest_neighbors=FLAGS.k_nearest_neighbors,
            use_softmax_feedback=FLAGS.use_softmax_feedback,
            initial_softmax_feedback=init_softmax,
            embedding_seg_feature_dimension=(
                FLAGS.embedding_seg_feature_dimension),
            embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
            embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
            embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
            normalize_nearest_neighbor_distances=(
                FLAGS.normalize_nearest_neighbor_distances),
            also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
            damage_initial_previous_frame_mask=(
                FLAGS.damage_initial_previous_frame_mask),
            use_local_previous_frame_attention=(
                FLAGS.use_local_previous_frame_attention),
            previous_frame_attention_window_size=(
                FLAGS.previous_frame_attention_window_size),
            use_first_frame_matching=FLAGS.use_first_frame_matching
        ))
  else:
    outputs_to_scales_to_logits = model.multi_scale_logits_v2(
        samples[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

  if model_options.classification_loss == 'softmax':
    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          samples[common.LABEL],
          num_classes,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output)
  elif model_options.classification_loss == 'triplet':
    for output, _ in six.iteritems(outputs_to_num_classes):
      train_utils.add_triplet_loss_for_each_scale(
          FLAGS.train_batch_size // FLAGS.num_clones,
          FLAGS.train_num_frames_per_video,
          FLAGS.embedding_dimension, outputs_to_scales_to_logits[output],
          samples[common.LABEL], scope=output)
  elif model_options.classification_loss == 'softmax_with_attention':
    labels = samples[common.LABEL]
    batch_size = FLAGS.train_batch_size // FLAGS.num_clones
    num_frames_per_video = FLAGS.train_num_frames_per_video
    h, w = train_utils.resolve_shape(labels)[1:3]
    labels = tf.reshape(labels, tf.stack(
        [batch_size, num_frames_per_video, h, w, 1]))
    # Strip the reference labels off.
    if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
      n_ref_frames = 2
    else:
      n_ref_frames = 1
    labels = labels[:, n_ref_frames:]
    # Merge batch and time dimensions.
    labels = tf.reshape(labels, tf.stack(
        [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))

    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          labels,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output,
          top_k_percent_pixels=FLAGS.top_k_percent_pixels,
          hard_example_mining_step=FLAGS.hard_example_mining_step)
  else:
    raise ValueError('Only support softmax, softmax_with_attention'
                     ' or triplet for classification_loss.')

  return outputs_to_scales_to_logits
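
One small detail in both _build_deeplab variants: the ground-truth summaries multiply the uint8 labels by 32 purely so that small object ids become visible gray levels in the image summaries. A NumPy sketch:

import numpy as np

label = np.array([[0, 1], [2, 7]], dtype=np.int32)  # hypothetical object ids
summary_img = label.astype(np.uint8) * 32           # ids 0..7 become gray levels 0..224
print(summary_img)  # [[  0  32]
                    #  [ 64 224]]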
Example No. 11
def get(dataset,
        num_frames_per_video,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        preprocess_image_and_label=True,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None,
        batch_capacity_factor=32,
        video_frames_are_decoded=False,
        decoder_output_stride=None,
        first_frame_finetuning=False,
        sample_only_first_frame_for_finetuning=False,
        sample_adjacent_and_consistent_query_frames=False,
        remap_labels_to_reference_frame=True,
        generate_prev_frame_mask_by_mask_damaging=False,
        three_frame_dataset=False,
        add_prev_frame_label=True):
    """Gets the dataset split for semantic segmentation.

  This function gets the dataset split for semantic segmentation. In
  particular, it is a wrapper of (1) dataset_data_provider, which returns the
  raw dataset split, (2) input_preprocess, which preprocesses the raw data,
  and (3) the TensorFlow operation that batches the preprocessed data. The
  output can then be directly used for training, evaluation, or visualization.

  Args:
    dataset: An instance of slim Dataset.
    num_frames_per_video: The number of frames used per video.
    crop_size: Image crop size [height, width].
    batch_size: Batch size.
    min_resize_value: Desired size of the smaller image side.
    max_resize_value: Maximum allowed size of the larger image side.
    resize_factor: Resized dimensions are multiple of factor plus one.
    min_scale_factor: Minimum scale factor value.
    max_scale_factor: Maximum scale factor value.
    scale_factor_step_size: The step size from min scale factor to max scale
      factor. The input is randomly scaled based on the value of
      (min_scale_factor, max_scale_factor, scale_factor_step_size).
    preprocess_image_and_label: Boolean variable specifies if preprocessing of
      image and label will be performed or not.
    num_readers: Number of readers for data provider.
    num_threads: Number of threads for batching data.
    dataset_split: Dataset split.
    is_training: Is training or not.
    model_variant: Model variant (string) for choosing how to mean-subtract the
      images. See feature_extractor.network_map for supported model variants.
    batch_capacity_factor: Batch capacity factor affecting the training queue
      batch capacity.
    video_frames_are_decoded: Boolean, whether the video frames are already
      decoded.
    decoder_output_stride: Integer, the stride of the decoder output.
    first_frame_finetuning: Boolean, whether to only sample the first frame
      for fine-tuning.
    sample_only_first_frame_for_finetuning: Boolean, whether to only sample the
      first frame during fine-tuning. This should be False when using lucid or
      wonderland data, but true when fine-tuning on the first frame only.
      Only has an effect if first_frame_finetuning is True.
    sample_adjacent_and_consistent_query_frames: Boolean, if true, the query
      frames (all but the first frame which is the reference frame) will be
      sampled such that they are adjacent video frames and have the same
      crop coordinates and flip augmentation.
    remap_labels_to_reference_frame: Boolean, whether to remap the labels of
      the query frames to match the labels of the (downscaled) reference frame.
      If a query frame contains a label which is not present in the reference,
      it will be mapped to background.
    generate_prev_frame_mask_by_mask_damaging: Boolean, whether to generate
      the masks used as guidance from the previous frame by damaging the
      ground truth mask.
    three_frame_dataset: Boolean, whether the dataset has exactly three frames
      per video of which the first is to be used as reference and the two
      others are consecutive frames to be used as query frames.
    add_prev_frame_label: Boolean, whether to sample one more frame before the
      first query frame to obtain a previous frame label. Only has an effect,
      if sample_adjacent_and_consistent_query_frames is True and
      generate_prev_frame_mask_by_mask_damaging is False.

  Returns:
    A dictionary of batched Tensors for semantic segmentation.

  Raises:
    ValueError: dataset_split is None, or Failed to find labels.
  """
    if dataset_split is None:
        raise ValueError('Unknown dataset split.')
    if model_variant is None:
        tf.logging.warning('Please specify a model_variant. See '
                           'feature_extractor.network_map for supported model '
                           'variants.')

    data_provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=num_readers,
        num_epochs=None if is_training else 1,
        shuffle=is_training)
    image, label, object_label, image_name, height, width, video_id = _get_data(
        data_provider, dataset_split, video_frames_are_decoded)

    sampling_is_valid = tf.constant(True)
    if num_frames_per_video is not None:
        total_num_frames = tf.shape(image)[0]
        if first_frame_finetuning or three_frame_dataset:
            if sample_only_first_frame_for_finetuning:
                assert not sample_adjacent_and_consistent_query_frames, (
                    'this option does not make sense for sampling only first frame.'
                )
                # Sample the first frame num_frames_per_video times.
                sel_indices = tf.tile(
                    tf.constant(0, dtype=tf.int32)[tf.newaxis],
                    multiples=[num_frames_per_video])
            else:
                if sample_adjacent_and_consistent_query_frames:
                    if add_prev_frame_label:
                        num_frames_per_video += 1
                    # Since this is first frame fine-tuning, we'll for now assume that
                    # each sequence has exactly 3 images: the ref frame and 2 adjacent
                    # query frames.
                    assert num_frames_per_video == 3
                    with tf.control_dependencies(
                        [tf.assert_equal(total_num_frames, 3)]):
                        sel_indices = tf.constant([1, 2], dtype=tf.int32)
                else:
                    # Sample num_frames_per_video - 1 query frames which are not the
                    # first frame.
                    sel_indices = tf.random_shuffle(
                        tf.range(1, total_num_frames))[:(num_frames_per_video -
                                                         1)]
                # Concat first frame as reference frame to the front.
                sel_indices = tf.concat(
                    [tf.constant(0, dtype=tf.int32)[tf.newaxis], sel_indices],
                    axis=0)
        else:
            if sample_adjacent_and_consistent_query_frames:
                if add_prev_frame_label:
                    # Sample one more frame which we can use to provide initial softmax
                    # feedback.
                    num_frames_per_video += 1
                ref_idx = tf.random_shuffle(tf.range(total_num_frames))[0]
                sampling_is_valid = tf.greater_equal(total_num_frames,
                                                     num_frames_per_video)

                def sample_query_start_idx():
                    return tf.random_shuffle(
                        tf.range(total_num_frames - num_frames_per_video +
                                 1))[0]

                query_start_idx = tf.cond(
                    sampling_is_valid, sample_query_start_idx,
                    lambda: tf.constant(0, dtype=tf.int32))

                def sample_sel_indices():
                    return tf.concat([
                        ref_idx[tf.newaxis],
                        tf.range(query_start_idx,
                                 query_start_idx + (num_frames_per_video - 1))
                    ], axis=0)

                sel_indices = tf.cond(
                    sampling_is_valid, sample_sel_indices,
                    lambda: tf.zeros((num_frames_per_video,), dtype=tf.int32))
            else:
                # Randomly sample some frames from the video.
                sel_indices = tf.random_shuffle(
                    tf.range(total_num_frames))[:num_frames_per_video]
        image = tf.gather(image, sel_indices, axis=0)
    if not video_frames_are_decoded:
        image = decode_image_sequence(image)

    if label is not None:
        if num_frames_per_video is not None:
            label = tf.gather(label, sel_indices, axis=0)
        if not video_frames_are_decoded:
            label = decode_image_sequence(label,
                                          image_format='png',
                                          channels=1)

        # Sometimes, label is saved as [num_frames_per_video, height, width] or
        # [num_frames_per_video, height, width, 1]. We change it to be
        # [num_frames_per_video, height, width, 1].
        if label.shape.ndims == 3:
            label = tf.expand_dims(label, 3)
        elif label.shape.ndims == 4 and label.shape.dims[3] == 1:
            pass
        else:
            raise ValueError('Input label shape must be '
                             '[num_frames_per_video, height, width] or '
                             '[num_frames_per_video, height, width, 1]. '
                             'Got ndims {}.'.format(label.shape.ndims))
        label.set_shape([None, None, None, 1])

    # Add size of first dimension since tf can't figure it out automatically.
    image.set_shape((num_frames_per_video, None, None, None))
    if label is not None:
        label.set_shape((num_frames_per_video, None, None, None))

    preceding_frame_label = None
    if preprocess_image_and_label:
        if num_frames_per_video is None:
            raise ValueError(
                'num_frame_per_video must be specified for preproc.')
        original_images = []
        images = []
        labels = []
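        # With adjacent and consistent query frames, only the reference frame
        # is preprocessed on its own; the query frames share one random
        # crop/scale so they stay spatially aligned with each other.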
        if sample_adjacent_and_consistent_query_frames:
            num_frames_individual_preproc = 1
        else:
            num_frames_individual_preproc = num_frames_per_video
        for frame_idx in range(num_frames_individual_preproc):
            original_image_t, image_t, label_t = (
                input_preprocess.preprocess_image_and_label(
                    image[frame_idx],
                    label[frame_idx],
                    crop_height=crop_size[0]
                    if crop_size is not None else None,
                    crop_width=crop_size[1] if crop_size is not None else None,
                    min_resize_value=min_resize_value,
                    max_resize_value=max_resize_value,
                    resize_factor=resize_factor,
                    min_scale_factor=min_scale_factor,
                    max_scale_factor=max_scale_factor,
                    scale_factor_step_size=scale_factor_step_size,
                    ignore_label=dataset.ignore_label,
                    is_training=is_training,
                    model_variant=model_variant))
            original_images.append(original_image_t)
            images.append(image_t)
            labels.append(label_t)
        if sample_adjacent_and_consistent_query_frames:
            imgs_for_preproc = [
                image[frame_idx]
                for frame_idx in range(1, num_frames_per_video)
            ]
            labels_for_preproc = [
                label[frame_idx]
                for frame_idx in range(1, num_frames_per_video)
            ]
            original_image_rest, image_rest, label_rest = (
                input_preprocess.preprocess_images_and_labels_consistently(
                    imgs_for_preproc,
                    labels_for_preproc,
                    crop_height=crop_size[0]
                    if crop_size is not None else None,
                    crop_width=crop_size[1] if crop_size is not None else None,
                    min_resize_value=min_resize_value,
                    max_resize_value=max_resize_value,
                    resize_factor=resize_factor,
                    min_scale_factor=min_scale_factor,
                    max_scale_factor=max_scale_factor,
                    scale_factor_step_size=scale_factor_step_size,
                    ignore_label=dataset.ignore_label,
                    is_training=is_training,
                    model_variant=model_variant))
            original_images.extend(original_image_rest)
            images.extend(image_rest)
            labels.extend(label_rest)
        assert len(original_images) == num_frames_per_video
        assert len(images) == num_frames_per_video
        assert len(labels) == num_frames_per_video

        if remap_labels_to_reference_frame:
            # Remap labels to indices into the labels of the (downscaled) reference
            # frame, or 0, i.e. background, for labels which are not present
            # in the reference.
            reference_labels = labels[0][tf.newaxis]
            h, w = train_utils.resolve_shape(reference_labels)[1:3]
            embedding_height = model.scale_dimension(
                h, 1.0 / decoder_output_stride)
            embedding_width = model.scale_dimension(
                w, 1.0 / decoder_output_stride)
            reference_labels_embedding_size = tf.squeeze(
                tf.image.resize_nearest_neighbor(
                    reference_labels,
                    tf.stack([embedding_height, embedding_width]),
                    align_corners=True),
                axis=0)
            # Get sorted unique labels in the reference frame.
            labels_in_ref_frame, _ = tf.unique(
                tf.reshape(reference_labels_embedding_size, [-1]))
            labels_in_ref_frame = tf.contrib.framework.sort(
                labels_in_ref_frame)
            for idx in range(1, len(labels)):
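                # ref_label_mask[..., k] is True where a pixel's label equals
                # the k-th (sorted) reference label, so the argmax below gives
                # the index of the matching reference label. (A standalone toy
                # example of this remapping follows after this function.)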
                ref_label_mask = tf.equal(
                    labels[idx], labels_in_ref_frame[tf.newaxis,
                                                     tf.newaxis, :])
                remapped = tf.argmax(tf.cast(ref_label_mask, tf.uint8),
                                     axis=-1,
                                     output_type=tf.int32)
                # Set to 0 if label is not present
                is_in_ref = tf.reduce_any(ref_label_mask, axis=-1)
                remapped *= tf.cast(is_in_ref, tf.int32)
                labels[idx] = remapped[..., tf.newaxis]

        if sample_adjacent_and_consistent_query_frames:
            if first_frame_finetuning and generate_prev_frame_mask_by_mask_damaging:
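                # Simulate an imperfect previous-frame prediction by randomly
                # damaging a ground-truth mask.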
                preceding_frame_label = mask_damaging.damage_masks(labels[1])
            elif add_prev_frame_label:
                # Discard the image of the additional frame and take the label as
                # initialization for softmax feedback.
                original_images = [original_images[0]] + original_images[2:]
                preceding_frame_label = labels[1]
                images = [images[0]] + images[2:]
                labels = [labels[0]] + labels[2:]
                num_frames_per_video -= 1

        original_image = tf.stack(original_images, axis=0)
        image = tf.stack(images, axis=0)
        label = tf.stack(labels, axis=0)
    else:
        if label is not None:
            # Need to set label shape due to batching.
            label.set_shape([
                num_frames_per_video,
                None if crop_size is None else crop_size[0],
                None if crop_size is None else crop_size[1], 1
            ])
        original_image = tf.to_float(tf.zeros_like(label))
        if crop_size is None:
            height = tf.shape(image)[1]
            width = tf.shape(image)[2]
        else:
            height = crop_size[0]
            width = crop_size[1]

    sample = {
        'image': image,
        'image_name': image_name,
        'height': height,
        'width': width,
        'video_id': video_id
    }
    if label is not None:
        sample['label'] = label

    if object_label is not None:
        sample['object_label'] = object_label

    if preceding_frame_label is not None:
        sample['preceding_frame_label'] = preceding_frame_label

    if not is_training:
        # Original image is only used during visualization.
        sample['original_image'] = original_image

    if is_training:
        if first_frame_finetuning:
            keep_input = tf.constant(True)
        else:
            keep_input = tf.logical_and(
                sampling_is_valid,
                tf.logical_and(
                    _has_enough_pixels_of_each_object_in_first_frame(
                        label, decoder_output_stride),
                    _has_foreground_and_background_in_first_frame_2(
                        label, decoder_output_stride)))

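        # maybe_batch only enqueues samples for which keep_input is True, so
        # invalid or insufficiently labeled sequences never reach training.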
        batched = tf.train.maybe_batch(sample,
                                       keep_input=keep_input,
                                       batch_size=batch_size,
                                       num_threads=num_threads,
                                       capacity=batch_capacity_factor *
                                       batch_size,
                                       dynamic_pad=True)
    else:
        batched = tf.train.batch(sample,
                                 batch_size=batch_size,
                                 num_threads=num_threads,
                                 capacity=batch_capacity_factor * batch_size,
                                 dynamic_pad=True)

    # Flatten from [batch, num_frames_per_video, ...] to
    # [batch * num_frames_per_video, ...].
    cropped_height = train_utils.resolve_shape(batched['image'])[2]
    cropped_width = train_utils.resolve_shape(batched['image'])[3]
    if num_frames_per_video is None:
        first_dim = -1
    else:
        first_dim = batch_size * num_frames_per_video
    batched['image'] = tf.reshape(
        batched['image'], [first_dim, cropped_height, cropped_width, 3])
    if label is not None:
        batched['label'] = tf.reshape(
            batched['label'], [first_dim, cropped_height, cropped_width, 1])
    return batched
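
A minimal, self-contained sketch (not part of the original example, TensorFlow
1.x assumed) of the label-remapping trick used above: reference labels such as
[0, 3, 7] become the indices [0, 1, 2], and labels absent from the reference
frame fall back to 0 (background). The toy tensors below are made up for
illustration.

import tensorflow as tf

# Hypothetical inputs: sorted unique labels of the reference frame and the
# raw labels of one query frame of shape [height, width].
toy_ref_labels = tf.constant([0, 3, 7], dtype=tf.int32)
toy_labels = tf.constant([[0, 3], [7, 5]], dtype=tf.int32)

# Compare every pixel against every reference label (broadcast to [h, w, 3]).
mask = tf.equal(toy_labels[..., tf.newaxis],
                toy_ref_labels[tf.newaxis, tf.newaxis, :])
# The argmax over the last axis yields the index of the matching reference
# label; pixels whose label is absent from the reference are set to 0.
remapped = tf.argmax(tf.cast(mask, tf.uint8), axis=-1, output_type=tf.int32)
remapped *= tf.cast(tf.reduce_any(mask, axis=-1), tf.int32)

with tf.Session() as sess:
    print(sess.run(remapped))  # [[0, 1], [2, 0]]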
Example #12
0
def predict_labels(images,
                   model_options,
                   image_pyramid=None,
                   reference_labels=None,
                   k_nearest_neighbors=1,
                   embedding_dimension=None,
                   use_softmax_feedback=False,
                   initial_softmax_feedback=None,
                   embedding_seg_feature_dimension=256,
                   embedding_seg_n_layers=4,
                   embedding_seg_kernel_size=7,
                   embedding_seg_atrous_rates=None,
                   also_return_softmax_probabilities=False,
                   num_frames_per_video=None,
                   normalize_nearest_neighbor_distances=False,
                   also_attend_to_previous_frame=False,
                   use_local_previous_frame_attention=False,
                   previous_frame_attention_window_size=9,
                   use_first_frame_matching=True,
                   also_return_embeddings=False,
                   ref_embeddings=None):
  """Predicts segmentation labels.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: An InternalModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.
    reference_labels: A tensor of size [batch, height, width, 1]. Ground truth
      labels used to perform the nearest neighbor query.
    k_nearest_neighbors: Integer, the number of neighbors to use for nearest
      neighbor queries.
    embedding_dimension: Integer, the dimension used for the learned embedding.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: Float32 tensor, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Typically only useful for inference. Only has an effect if
      use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    also_return_softmax_probabilities: Boolean, if true, additionally return
      the softmax probabilities as second return value.
    num_frames_per_video: Integer, the number of frames per video.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect, if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.

  Returns:
    A dictionary with keys specifying the output_type (e.g., semantic
      prediction) and values storing Tensors representing predictions (argmax
      over channels). Each prediction has size [batch, height, width].
    If also_return_softmax_probabilities is True, the second return value is
      the softmax probabilities.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      nor triplet.
  """
  if (model_options.classification_loss == 'triplet' and
      reference_labels is None):
    raise ValueError('Need reference_labels for triplet loss')

  if model_options.classification_loss == 'softmax_with_attention':
    if embedding_dimension is None:
      raise ValueError('Need embedding_dimension for softmax_with_attention '
                       'loss')
    if reference_labels is None:
      raise ValueError('Need reference_labels for softmax_with_attention loss')
    res = (
        multi_scale_logits_with_nearest_neighbor_matching(
            images,
            model_options=model_options,
            image_pyramid=image_pyramid,
            is_training=False,
            reference_labels=reference_labels,
            clone_batch_size=1,
            num_frames_per_video=num_frames_per_video,
            embedding_dimension=embedding_dimension,
            max_neighbors_per_object=0,
            k_nearest_neighbors=k_nearest_neighbors,
            use_softmax_feedback=use_softmax_feedback,
            initial_softmax_feedback=initial_softmax_feedback,
            embedding_seg_feature_dimension=embedding_seg_feature_dimension,
            embedding_seg_n_layers=embedding_seg_n_layers,
            embedding_seg_kernel_size=embedding_seg_kernel_size,
            embedding_seg_atrous_rates=embedding_seg_atrous_rates,
            normalize_nearest_neighbor_distances=
            normalize_nearest_neighbor_distances,
            also_attend_to_previous_frame=also_attend_to_previous_frame,
            use_local_previous_frame_attention=
            use_local_previous_frame_attention,
            previous_frame_attention_window_size=
            previous_frame_attention_window_size,
            use_first_frame_matching=use_first_frame_matching,
            also_return_embeddings=also_return_embeddings,
            ref_embeddings=ref_embeddings
        ))
    if also_return_embeddings:
      outputs_to_scales_to_logits, embeddings = res
    else:
      outputs_to_scales_to_logits = res
      embeddings = None
  else:
    outputs_to_scales_to_logits = multi_scale_logits_v2(
        images,
        model_options=model_options,
        image_pyramid=image_pyramid,
        is_training=False,
        fine_tune_batch_norm=False)

  predictions = {}
  for output in sorted(outputs_to_scales_to_logits):
    scales_to_logits = outputs_to_scales_to_logits[output]
    original_logits = scales_to_logits[MERGED_LOGITS_SCOPE]
    if isinstance(original_logits, list):
      assert len(original_logits) == 1
      original_logits = original_logits[0]
    logits = tf.image.resize_bilinear(original_logits, tf.shape(images)[1:3],
                                      align_corners=True)
    if model_options.classification_loss in ('softmax',
                                             'softmax_with_attention'):
      predictions[output] = tf.argmax(logits, 3)
    elif model_options.classification_loss == 'triplet':
      # To keep this fast, we do the nearest neighbor assignment at the
      # resolution at which the embedding is extracted and scale the result up
      # afterwards.
      embeddings = original_logits
      reference_labels_logits_size = tf.squeeze(
          tf.image.resize_nearest_neighbor(
              reference_labels[tf.newaxis],
              train_utils.resolve_shape(embeddings)[1:3],
              align_corners=True), axis=0)
      nn_labels = embedding_utils.assign_labels_by_nearest_neighbors(
          embeddings[0], embeddings[1:], reference_labels_logits_size,
          k_nearest_neighbors)
      predictions[common.OUTPUT_TYPE] = tf.image.resize_nearest_neighbor(
          nn_labels, tf.shape(images)[1:3], align_corners=True)
    else:
      raise ValueError(
          'Only support softmax, triplet, or softmax_with_attention for '
          'classification_loss.')

  if also_return_embeddings:
    assert also_return_softmax_probabilities
    return predictions, tf.nn.softmax(original_logits, axis=-1), embeddings
  elif also_return_softmax_probabilities:
    return predictions, tf.nn.softmax(original_logits, axis=-1)
  else:
    return predictions
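
A hedged usage sketch (not part of the original example) showing how the
possible return signatures of predict_labels line up with its flags. The
tensors `images` and `ref_labels` and the model options object `options` are
placeholders assumed to be built elsewhere in the codebase; the embedding
dimension of 100 is only an illustrative value.

# Hypothetical inputs: `images` is [batch, height, width, 3], `ref_labels` is
# [batch, height, width, 1], and `options` uses
# classification_loss == 'softmax_with_attention'.
predictions, softmax_probabilities, embeddings = predict_labels(
    images,
    model_options=options,
    reference_labels=ref_labels,
    embedding_dimension=100,
    use_softmax_feedback=True,
    also_return_softmax_probabilities=True,
    also_return_embeddings=True)
# Assuming the output key used elsewhere in this example:
semantic_prediction = predictions[common.OUTPUT_TYPE]  # [batch, height, width]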