Example #1
def create_initial_softmax_from_labels(last_frame_labels, reference_labels,
                                       decoder_output_stride, reduce_labels):
    """Creates initial softmax predictions from last frame labels.

    Args:
      last_frame_labels: last frame labels of shape [1, height, width, 1].
      reference_labels: reference frame labels of shape [1, height, width, 1].
      decoder_output_stride: Integer, the stride of the decoder. Can be None,
        in which case it is assumed that last_frame_labels and reference_labels
        are already scaled to the decoder output resolution.
      reduce_labels: Boolean, whether to reduce the depth of the softmax
        one_hot encoding to the actual number of labels present in the
        reference frame (otherwise the depth will be the highest label
        index + 1).

    Returns:
      init_softmax: the initial softmax predictions.
    """
    if decoder_output_stride is None:
        labels_output_size = last_frame_labels
        reference_labels_output_size = reference_labels
    else:
        h = tf.shape(last_frame_labels)[1]
        w = tf.shape(last_frame_labels)[2]
        h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
        w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
        labels_output_size = tf.image.resize_nearest_neighbor(
            last_frame_labels, [h_sub, w_sub], align_corners=True)
        reference_labels_output_size = tf.image.resize_nearest_neighbor(
            reference_labels, [h_sub, w_sub], align_corners=True)
    if reduce_labels:
        unique_labels, _ = tf.unique(
            tf.reshape(reference_labels_output_size, [-1]))
        depth = tf.size(unique_labels)
    else:
        depth = tf.reduce_max(reference_labels_output_size) + 1
    one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size),
                                       depth)
    with tf.control_dependencies([one_hot_assertion]):
        init_softmax = tf.one_hot(tf.squeeze(labels_output_size, axis=-1),
                                  depth=depth,
                                  dtype=tf.float32)
    return init_softmax
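A minimal usage sketch for the function above; the toy label tensors, shapes, and the TF 1.x graph-mode assumption are illustrative, not taken from the source:

import tensorflow as tf

# Toy labels of shape [1, height, width, 1]; ids 0 and 1 stand for background
# and a single object (illustrative values).
last_frame_labels = tf.constant([[[[0], [1]], [[1], [1]]]], dtype=tf.int32)
reference_labels = tf.constant([[[[0], [0]], [[1], [1]]]], dtype=tf.int32)

# decoder_output_stride=None means the labels are assumed to already be at the
# decoder output resolution, so no resizing happens inside the function.
init_softmax = create_initial_softmax_from_labels(
    last_frame_labels, reference_labels,
    decoder_output_stride=None, reduce_labels=True)
# reduce_labels=True -> depth equals the number of distinct reference labels
# (2 here), so init_softmax has shape [1, 2, 2, 2].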
Example #2
def create_initial_softmax_from_labels(last_frame_labels, reference_labels,
                                       decoder_output_stride, reduce_labels):
  """Creates initial softmax predictions from last frame labels.

  Args:
    last_frame_labels: last frame labels of shape [1, height, width, 1].
    reference_labels: reference frame labels of shape [1, height, width, 1].
    decoder_output_stride: Integer, the stride of the decoder. Can be None, in
      which case it is assumed that last_frame_labels and reference_labels
      are already scaled to the decoder output resolution.
    reduce_labels: Boolean, whether to reduce the depth of the softmax one_hot
      encoding to the actual number of labels present in the reference frame
      (otherwise the depth will be the highest label index + 1).

  Returns:
    init_softmax: the initial softmax predictions.
  """
  if decoder_output_stride is None:
    labels_output_size = last_frame_labels
    reference_labels_output_size = reference_labels
  else:
    h = tf.shape(last_frame_labels)[1]
    w = tf.shape(last_frame_labels)[2]
    h_sub = model.scale_dimension(h, 1.0 / decoder_output_stride)
    w_sub = model.scale_dimension(w, 1.0 / decoder_output_stride)
    labels_output_size = tf.image.resize_nearest_neighbor(
        last_frame_labels, [h_sub, w_sub], align_corners=True)
    reference_labels_output_size = tf.image.resize_nearest_neighbor(
        reference_labels, [h_sub, w_sub], align_corners=True)
  if reduce_labels:
    unique_labels, _ = tf.unique(tf.reshape(reference_labels_output_size, [-1]))
    depth = tf.size(unique_labels)
  else:
    depth = tf.reduce_max(reference_labels_output_size) + 1
  one_hot_assertion = tf.assert_less(tf.reduce_max(labels_output_size), depth)
  with tf.control_dependencies([one_hot_assertion]):
    init_softmax = tf.one_hot(tf.squeeze(labels_output_size,
                                         axis=-1),
                              depth=depth,
                              dtype=tf.float32)
  return init_softmax
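Both copies above are the same function; the knob worth illustrating is reduce_labels, which only changes how the one_hot depth is derived. A small sketch (label values chosen for illustration):

import tensorflow as tf

# Reference labels containing only the ids 0 and 5.
reference = tf.constant([[[[0], [5]], [[5], [0]]]], dtype=tf.int32)

# reduce_labels=True:  depth = number of distinct labels present -> 2
depth_reduced = tf.size(tf.unique(tf.reshape(reference, [-1]))[0])
# reduce_labels=False: depth = highest label index + 1           -> 6
depth_full = tf.reduce_max(reference) + 1

# Note that the assertion inside the function requires max(label) < depth, so
# reduce_labels=True implicitly assumes the labels have already been remapped
# to the contiguous range 0..N-1.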
Example #3
def get_embeddings(images, model_options, embedding_dimension):
    """Extracts embedding vectors for images. Should only be used for inference.

    Args:
      images: A tensor of shape [batch, height, width, channels].
      model_options: A ModelOptions instance to configure models.
      embedding_dimension: Integer, the dimension of the embedding.

    Returns:
      embeddings: A tensor of shape
        [batch, height, width, embedding_dimension].
    """
    features, end_points = model.extract_features(images,
                                                  model_options,
                                                  is_training=False)

    if model_options.decoder_output_stride is not None:
        if model_options.crop_size is None:
            height = tf.shape(images)[1]
            width = tf.shape(images)[2]
        else:
            height, width = model_options.crop_size
        decoder_height = model.scale_dimension(
            height, 1.0 / model_options.decoder_output_stride)
        decoder_width = model.scale_dimension(
            width, 1.0 / model_options.decoder_output_stride)
        features = model.refine_by_decoder(
            features,
            end_points,
            decoder_height=decoder_height,
            decoder_width=decoder_width,
            decoder_use_separable_conv=(
                model_options.decoder_use_separable_conv),
            model_variant=model_options.model_variant,
            is_training=False)

    with tf.variable_scope('embedding'):
        embeddings = split_separable_conv2d_with_identity_initializer(
            features, embedding_dimension, scope='split_separable_conv2d')
    return embeddings
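A hypothetical inference sketch for get_embeddings; the ModelOptions instance is assumed to be configured elsewhere, and the placeholder shape and embedding dimension are illustrative:

import tensorflow as tf

# `model_options` is assumed to be an already configured ModelOptions
# instance (e.g. built from the DeepLab `common` module); it is not
# constructed in this sketch.
images = tf.placeholder(tf.float32, shape=[1, 465, 465, 3], name='images')
embeddings = get_embeddings(images, model_options, embedding_dimension=100)
# With a decoder_output_stride of 4, the spatial size follows
# model.scale_dimension(465, 1.0 / 4), i.e. embeddings would be
# [1, 117, 117, 100] here.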
Example #4
def testScaleDimensionOutput(self):
    self.assertEqual(161, model.scale_dimension(321, 0.5))
    self.assertEqual(193, model.scale_dimension(321, 0.6))
    self.assertEqual(241, model.scale_dimension(321, 0.75))
Example #5
def testScaleDimensionOutput(self):
  self.assertEqual(161, model.scale_dimension(321, 0.5))
  self.assertEqual(193, model.scale_dimension(321, 0.6))
  self.assertEqual(241, model.scale_dimension(321, 0.75))
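Both test copies pin down the expected outputs of model.scale_dimension. The sketch below is one formula consistent with those values; it is an assumption for illustration, not code taken from the source:

def scale_dimension_sketch(dim, scale):
    # Scale the (dim - 1) span and add the end point back, in the style of
    # align_corners=True resizing, e.g. (321 - 1) * 0.5 + 1 = 161.
    return int((float(dim) - 1.0) * scale + 1.0)

assert scale_dimension_sketch(321, 0.5) == 161
assert scale_dimension_sketch(321, 0.6) == 193
assert scale_dimension_sketch(321, 0.75) == 241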
Example #6
def get_logits_with_matching(images,
                             model_options,
                             weight_decay=0.0001,
                             reuse=None,
                             is_training=False,
                             fine_tune_batch_norm=False,
                             reference_labels=None,
                             batch_size=None,
                             num_frames_per_video=None,
                             embedding_dimension=None,
                             max_neighbors_per_object=0,
                             k_nearest_neighbors=1,
                             use_softmax_feedback=True,
                             initial_softmax_feedback=None,
                             embedding_seg_feature_dimension=256,
                             embedding_seg_n_layers=4,
                             embedding_seg_kernel_size=7,
                             embedding_seg_atrous_rates=None,
                             normalize_nearest_neighbor_distances=True,
                             also_attend_to_previous_frame=True,
                             damage_initial_previous_frame_mask=False,
                             use_local_previous_frame_attention=True,
                             previous_frame_attention_window_size=15,
                             use_first_frame_matching=True,
                             also_return_embeddings=False,
                             ref_embeddings=None):
  """Gets the logits by atrous/image spatial pyramid pooling using attention.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    reference_labels: The segmentation labels of the reference frame on which
      attention is applied.
    batch_size: Integer, the number of videos in a batch.
    num_frames_per_video: Integer, the number of frames per video.
    embedding_dimension: Integer, the dimension of the embedding.
    max_neighbors_per_object: Integer, the maximum number of candidates
      for the nearest neighbor query per object after subsampling.
      Can be 0 for no subsampling.
    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Only has an effect if use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    damage_initial_previous_frame_mask: Boolean, whether to artificially damage
      the initial previous frame mask. Only has an effect if
      also_attend_to_previous_frame is True.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.
  Returns:
    outputs_to_logits: A map from output_type to logits.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  with tf.variable_scope('embedding', reuse=reuse):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
    embeddings = tf.identity(embeddings, name='embeddings')
  scaled_reference_labels = tf.image.resize_nearest_neighbor(
      reference_labels,
      resolve_shape(embeddings, 4)[1:3],
      align_corners=True)
  h, w = decoder_height, decoder_width
  if num_frames_per_video is None:
    num_frames_per_video = tf.size(embeddings) // (
        batch_size * h * w * embedding_dimension)
  new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
  reshaped_reference_labels = tf.reshape(scaled_reference_labels,
                                         new_labels_shape)
  new_embeddings_shape = tf.stack([batch_size,
                                   num_frames_per_video, h, w,
                                   embedding_dimension])
  reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
  all_nn_features = []
  all_ref_obj_ids = []
  # To keep things simple, we process each sequence separately for now.
  for n in range(batch_size):
    embedding = reshaped_embeddings[n]
    if ref_embeddings is None:
      n_chunks = 100
      reference_embedding = embedding[0]
      if also_attend_to_previous_frame or use_softmax_feedback:
        queries_embedding = embedding[2:]
      else:
        queries_embedding = embedding[1:]
    else:
      if USE_CORRELATION_COST:
        n_chunks = 20
      else:
        n_chunks = 500
      reference_embedding = ref_embeddings[0][n]
      queries_embedding = embedding
    reference_labels = reshaped_reference_labels[n][0]
    nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
        reference_embedding, queries_embedding, reference_labels,
        max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
    if normalize_nearest_neighbor_distances:
      nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
    all_nn_features.append(nn_features_n)
    all_ref_obj_ids.append(ref_obj_ids)

  feat_dim = resolve_shape(features)[-1]
  features = tf.reshape(features, tf.stack(
      [batch_size, num_frames_per_video, h, w, feat_dim]))
  if ref_embeddings is None:
    # Strip the features for the reference frame.
    if also_attend_to_previous_frame or use_softmax_feedback:
      features = features[:, 2:]
    else:
      features = features[:, 1:]

  # To keep things simple, we process each sequence separately for now.
  outputs_to_logits = {output: [] for
                       output in model_options.outputs_to_num_classes}
  for n in range(batch_size):
    features_n = features[n]
    nn_features_n = all_nn_features[n]
    nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
    n_objs = tf.shape(nn_features_n_tr)[0]
    # Repeat features for every object.
    features_n_tiled = tf.tile(features_n[tf.newaxis],
                               multiples=[n_objs, 1, 1, 1, 1])
    prev_frame_labels = None
    if also_attend_to_previous_frame:
      prev_frame_labels = reshaped_reference_labels[n, 1]
      if is_training and damage_initial_previous_frame_mask:
        # Damage the previous frame masks.
        prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
                                                       dilate=False)
      tf.summary.image('prev_frame_labels',
                       tf.cast(prev_frame_labels[tf.newaxis],
                               tf.uint8) * 32)
      initial_softmax_feedback_n = create_initial_softmax_from_labels(
          prev_frame_labels, reshaped_reference_labels[n][0],
          decoder_output_stride=None, reduce_labels=True)
    elif initial_softmax_feedback is not None:
      initial_softmax_feedback_n = initial_softmax_feedback[n]
    else:
      initial_softmax_feedback_n = None
    if initial_softmax_feedback_n is None:
      last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
    else:
      last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
          ..., tf.newaxis]
    assert len(model_options.outputs_to_num_classes) == 1
    output = list(model_options.outputs_to_num_classes.keys())[0]
    logits = []
    n_ref_frames = 1
    prev_frame_nn_features_n = None
    if also_attend_to_previous_frame or use_softmax_feedback:
      n_ref_frames += 1
    if ref_embeddings is not None:
      n_ref_frames = 0
    for t in range(num_frames_per_video - n_ref_frames):
      to_concat = [features_n_tiled[:, t]]
      if use_first_frame_matching:
        to_concat.append(nn_features_n_tr[:, t])
      if use_softmax_feedback:
        to_concat.append(last_softmax)
      if also_attend_to_previous_frame:
        assert normalize_nearest_neighbor_distances, (
            'previous frame attention currently only works when normalized '
            'distances are used')
        embedding = reshaped_embeddings[n]
        if ref_embeddings is None:
          last_frame_embedding = embedding[t + 1]
          query_embeddings = embedding[t + 2, tf.newaxis]
        else:
          last_frame_embedding = ref_embeddings[1][0]
          query_embeddings = embedding
        if use_local_previous_frame_attention:
          assert query_embeddings.shape[0] == 1
          prev_frame_nn_features_n = (
              local_previous_frame_nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings[0],
                  prev_frame_labels,
                  all_ref_obj_ids[n],
                  max_distance=previous_frame_attention_window_size)
          )
        else:
          prev_frame_nn_features_n, _ = (
              nearest_neighbor_features_per_object(
                  last_frame_embedding, query_embeddings, prev_frame_labels,
                  max_neighbors_per_object, k_nearest_neighbors,
                  gt_ids=all_ref_obj_ids[n]))
          prev_frame_nn_features_n = (tf.nn.sigmoid(
              prev_frame_nn_features_n) - 0.5) * 2
        prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
                                                 axis=0)
        prev_frame_nn_features_n_tr = tf.transpose(
            prev_frame_nn_features_n_sq, [2, 0, 1, 3])
        to_concat.append(prev_frame_nn_features_n_tr)
      features_n_concat_t = tf.concat(to_concat, axis=-1)
      embedding_seg_features_n_t = (
          create_embedding_segmentation_features(
              features_n_concat_t, embedding_seg_feature_dimension,
              embedding_seg_n_layers, embedding_seg_kernel_size,
              reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
      logits_t = model.get_branch_logits(
          embedding_seg_features_n_t,
          1,
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse or n > 0 or t > 0,
          scope_suffix=output
      )
      logits.append(logits_t)
      prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
                                       [2, 0, 1])
      last_softmax = tf.nn.softmax(logits_t, axis=0)
    logits = tf.stack(logits, axis=1)
    logits_shape = tf.stack(
        [n_objs, num_frames_per_video - n_ref_frames] +
        resolve_shape(logits)[2:-1])
    logits_reshaped = tf.reshape(logits, logits_shape)
    logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
    outputs_to_logits[output].append(logits_transposed)

    add_image_summaries(
        images[n * num_frames_per_video: (n+1) * num_frames_per_video],
        nn_features_n,
        logits_transposed,
        batch_size=1,
        prev_frame_nn_features=prev_frame_nn_features_n)
  if also_return_embeddings:
    return outputs_to_logits, embeddings
  else:
    return outputs_to_logits
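A hypothetical call sketch for get_logits_with_matching; the tensor shapes, the pre-built `model_options`, and the output key are illustrative assumptions, not values from the source:

import tensorflow as tf

batch_size = 1
num_frames_per_video = 3  # reference frame + previous frame + 1 query frame
# Frames of all videos are stacked along the first dimension.
images = tf.placeholder(
    tf.float32, [batch_size * num_frames_per_video, 465, 465, 3])
reference_labels = tf.placeholder(
    tf.int32, [batch_size * num_frames_per_video, 465, 465, 1])

outputs_to_logits = get_logits_with_matching(
    images,
    model_options,
    is_training=True,
    reference_labels=reference_labels,
    batch_size=batch_size,
    num_frames_per_video=num_frames_per_video,
    embedding_dimension=100)
# outputs_to_logits maps the single output type (e.g. 'semantic') to a list
# with one entry per video, each shaped [num_query_frames, h, w, n_objects].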
Example #7
def get_logits_with_matching(images,
                             model_options,
                             weight_decay=0.0001,
                             reuse=None,
                             is_training=False,
                             fine_tune_batch_norm=False,
                             reference_labels=None,
                             batch_size=None,
                             num_frames_per_video=None,
                             embedding_dimension=None,
                             max_neighbors_per_object=0,
                             k_nearest_neighbors=1,
                             use_softmax_feedback=True,
                             initial_softmax_feedback=None,
                             embedding_seg_feature_dimension=256,
                             embedding_seg_n_layers=4,
                             embedding_seg_kernel_size=7,
                             embedding_seg_atrous_rates=None,
                             normalize_nearest_neighbor_distances=True,
                             also_attend_to_previous_frame=True,
                             damage_initial_previous_frame_mask=False,
                             use_local_previous_frame_attention=True,
                             previous_frame_attention_window_size=15,
                             use_first_frame_matching=True,
                             also_return_embeddings=False,
                             ref_embeddings=None):
  """Gets the logits by atrous/image spatial pyramid pooling using attention.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    reference_labels: The segmentation labels of the reference frame on which
      attention is applied.
    batch_size: Integer, the number of videos in a batch.
    num_frames_per_video: Integer, the number of frames per video.
    embedding_dimension: Integer, the dimension of the embedding.
    max_neighbors_per_object: Integer, the maximum number of candidates
      for the nearest neighbor query per object after subsampling.
      Can be 0 for no subsampling.
    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Only has an effect if use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    damage_initial_previous_frame_mask: Boolean, whether to artificially damage
      the initial previous frame mask. Only has an effect if
      also_attend_to_previous_frame is True.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.
  Returns:
    outputs_to_logits: A map from output_type to logits.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  with tf.variable_scope('embedding', reuse=reuse):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
    embeddings = tf.identity(embeddings, name='embeddings')
  scaled_reference_labels = tf.image.resize_nearest_neighbor(
      reference_labels,
      resolve_shape(embeddings, 4)[1:3],
      align_corners=True)
  h, w = decoder_height, decoder_width
  if num_frames_per_video is None:
    num_frames_per_video = tf.size(embeddings) // (
        batch_size * h * w * embedding_dimension)
  new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
  reshaped_reference_labels = tf.reshape(scaled_reference_labels,
                                         new_labels_shape)
  new_embeddings_shape = tf.stack([batch_size,
                                   num_frames_per_video, h, w,
                                   embedding_dimension])
  reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
  all_nn_features = []
  all_ref_obj_ids = []
  # To keep things simple, we process each sequence separately for now.
  for n in range(batch_size):
    embedding = reshaped_embeddings[n]
    if ref_embeddings is None:
      n_chunks = 100
      reference_embedding = embedding[0]
      if also_attend_to_previous_frame or use_softmax_feedback:
        queries_embedding = embedding[2:]
      else:
        queries_embedding = embedding[1:]
    else:
      if USE_CORRELATION_COST:
        n_chunks = 20
      else:
        n_chunks = 500
      reference_embedding = ref_embeddings[0][n]
      queries_embedding = embedding
    reference_labels = reshaped_reference_labels[n][0]
    nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
        reference_embedding, queries_embedding, reference_labels,
        max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
    if normalize_nearest_neighbor_distances:
      nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
    all_nn_features.append(nn_features_n)
    all_ref_obj_ids.append(ref_obj_ids)

  feat_dim = resolve_shape(features)[-1]
  features = tf.reshape(features, tf.stack(
      [batch_size, num_frames_per_video, h, w, feat_dim]))
  if ref_embeddings is None:
    # Strip the features for the reference frame.
    if also_attend_to_previous_frame or use_softmax_feedback:
      features = features[:, 2:]
    else:
      features = features[:, 1:]

  # To keep things simple, we process each sequence separately for now.
  outputs_to_logits = {output: [] for
                       output in model_options.outputs_to_num_classes}
  for n in range(batch_size):
    features_n = features[n]
    nn_features_n = all_nn_features[n]
    nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
    n_objs = tf.shape(nn_features_n_tr)[0]
    # Repeat features for every object.
    features_n_tiled = tf.tile(features_n[tf.newaxis],
                               multiples=[n_objs, 1, 1, 1, 1])
    prev_frame_labels = None
    if also_attend_to_previous_frame:
      prev_frame_labels = reshaped_reference_labels[n, 1]
      if is_training and damage_initial_previous_frame_mask:
        # Damage the previous frame masks.
        prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
                                                       dilate=False)
      tf.summary.image('prev_frame_labels',
                       tf.cast(prev_frame_labels[tf.newaxis],
                               tf.uint8) * 32)
      initial_softmax_feedback_n = create_initial_softmax_from_labels(
          prev_frame_labels, reshaped_reference_labels[n][0],
          decoder_output_stride=None, reduce_labels=True)
    elif initial_softmax_feedback is not None:
      initial_softmax_feedback_n = initial_softmax_feedback[n]
    else:
      initial_softmax_feedback_n = None
    if initial_softmax_feedback_n is None:
      last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
    else:
      last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
          ..., tf.newaxis]
    assert len(model_options.outputs_to_num_classes) == 1
    output = list(model_options.outputs_to_num_classes.keys())[0]
    logits = []
    n_ref_frames = 1
    prev_frame_nn_features_n = None
    if also_attend_to_previous_frame or use_softmax_feedback:
      n_ref_frames += 1
    if ref_embeddings is not None:
      n_ref_frames = 0
    for t in range(num_frames_per_video - n_ref_frames):
      to_concat = [features_n_tiled[:, t]]
      if use_first_frame_matching:
        to_concat.append(nn_features_n_tr[:, t])
      if use_softmax_feedback:
        to_concat.append(last_softmax)
      if also_attend_to_previous_frame:
        assert normalize_nearest_neighbor_distances, (
            'previous frame attention currently only works when normalized '
            'distances are used')
        embedding = reshaped_embeddings[n]
        if ref_embeddings is None:
          last_frame_embedding = embedding[t + 1]
          query_embeddings = embedding[t + 2, tf.newaxis]
        else:
          last_frame_embedding = ref_embeddings[1][0]
          query_embeddings = embedding
        if use_local_previous_frame_attention:
          assert query_embeddings.shape[0] == 1
          prev_frame_nn_features_n = (
              local_previous_frame_nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings[0],
                  prev_frame_labels,
                  all_ref_obj_ids[n],
                  max_distance=previous_frame_attention_window_size)
          )
        else:
          prev_frame_nn_features_n, _ = (
              nearest_neighbor_features_per_object(
                  last_frame_embedding, query_embeddings, prev_frame_labels,
                  max_neighbors_per_object, k_nearest_neighbors,
                  gt_ids=all_ref_obj_ids[n]))
          prev_frame_nn_features_n = (tf.nn.sigmoid(
              prev_frame_nn_features_n) - 0.5) * 2
        prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
                                                 axis=0)
        prev_frame_nn_features_n_tr = tf.transpose(
            prev_frame_nn_features_n_sq, [2, 0, 1, 3])
        to_concat.append(prev_frame_nn_features_n_tr)
      features_n_concat_t = tf.concat(to_concat, axis=-1)
      embedding_seg_features_n_t = (
          create_embedding_segmentation_features(
              features_n_concat_t, embedding_seg_feature_dimension,
              embedding_seg_n_layers, embedding_seg_kernel_size,
              reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
      logits_t = model.get_branch_logits(
          embedding_seg_features_n_t,
          1,
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse or n > 0 or t > 0,
          scope_suffix=output
      )
      logits.append(logits_t)
      prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
                                       [2, 0, 1])
      last_softmax = tf.nn.softmax(logits_t, axis=0)
    logits = tf.stack(logits, axis=1)
    logits_shape = tf.stack(
        [n_objs, num_frames_per_video - n_ref_frames] +
        resolve_shape(logits)[2:-1])
    logits_reshaped = tf.reshape(logits, logits_shape)
    logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
    outputs_to_logits[output].append(logits_transposed)

    add_image_summaries(
        images[n * num_frames_per_video: (n+1) * num_frames_per_video],
        nn_features_n,
        logits_transposed,
        batch_size=1,
        prev_frame_nn_features=prev_frame_nn_features_n)
  if also_return_embeddings:
    return outputs_to_logits, embeddings
  else:
    return outputs_to_logits
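As a complement, a hypothetical sketch of the ref_embeddings path described in the docstring: embeddings cached from an earlier pass are fed back so that only the new frame is embedded and matched. All names (initial_frames, current_frame, first_frame_embeddings, ...) and argument values are illustrative assumptions:

# Pass 1: embed and segment the initial frames once, keeping the embeddings.
outputs_to_logits, cached_embeddings = get_logits_with_matching(
    initial_frames,                 # reference + previous + first query frame
    model_options,
    reference_labels=initial_frame_labels,
    batch_size=1,
    num_frames_per_video=3,
    embedding_dimension=100,
    also_return_embeddings=True)

# Pass 2: for each later frame, reuse cached embeddings instead of
# re-embedding the reference and previous frames. The tuple order follows the
# docstring: (first_frame_embeddings, previous_frame_embeddings); slicing
# these out of cached_embeddings is left out of this sketch.
outputs_to_logits_t = get_logits_with_matching(
    current_frame,
    model_options,
    reference_labels=initial_frame_labels,
    batch_size=1,
    num_frames_per_video=1,
    embedding_dimension=100,
    ref_embeddings=(first_frame_embeddings, previous_frame_embeddings))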