def get_embeddings(images, model_options, embedding_dimension):
    """Extracts embedding vectors for images. Should only be used for inference.

  Args:
    images: A tensor of shape [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    embedding_dimension: Integer, the dimension of the embedding.

  Returns:
    embeddings: A tensor of shape [batch, height, width, embedding_dimension].
  """
    features, end_points = model.extract_features(images,
                                                  model_options,
                                                  is_training=False)

    if model_options.decoder_output_stride is not None:
        if model_options.crop_size is None:
            height = tf.shape(images)[1]
            width = tf.shape(images)[2]
        else:
            height, width = model_options.crop_size
        decoder_height = model.scale_dimension(
            height, 1.0 / model_options.decoder_output_stride)
        decoder_width = model.scale_dimension(
            width, 1.0 / model_options.decoder_output_stride)
        features = model.refine_by_decoder(
            features,
            end_points,
            decoder_height=decoder_height,
            decoder_width=decoder_width,
            decoder_use_separable_conv=model_options.decoder_use_separable_conv,
            model_variant=model_options.model_variant,
            is_training=False)

    with tf.variable_scope('embedding'):
        embeddings = split_separable_conv2d_with_identity_initializer(
            features, embedding_dimension, scope='split_separable_conv2d')
    return embeddings
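
The snippets in this listing are excerpts and do not show their imports. A hedged guess at the shared preamble follows; the project-specific import paths are assumptions based on the identifiers used.

import glob

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Assumed project-specific imports (paths are a guess from the identifiers used):
# from deeplab import common, model
# Helpers such as RDN, PatchDiscriminator, read_pngs, log_all_summaries,
# resolve_shape, the nearest-neighbor feature functions and PATH_TO_TRAINED_MODEL
# are defined elsewhere in the surrounding project.
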
def get_embeddings(images, model_options, embedding_dimension):
  """Extracts embedding vectors for images. Should only be used for inference.

  Args:
    images: A tensor of shape [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    embedding_dimension: Integer, the dimension of the embedding.

  Returns:
    embeddings: A tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      is_training=False)

  if model_options.decoder_output_stride is not None:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        is_training=False)

  with tf.variable_scope('embedding'):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
  return embeddings
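
As a usage illustration, here is a minimal inference sketch. It assumes the surrounding FeelVOS-style module is importable (so get_embeddings, common and model resolve) and that a compatible checkpoint exists; the helper name, the placeholder checkpoint path and the 100-dimensional default are illustrative, while the ModelOptions values mirror the train() example further down.

def run_embedding_inference(image_batch, checkpoint_path, embedding_dimension=100):
    """Hypothetical helper: computes embeddings for a float32 image batch."""
    with tf.Graph().as_default():
        images = tf.placeholder(tf.float32, shape=[None, None, None, 3])
        model_options = common.ModelOptions(  # values mirror the train() example
            outputs_to_num_classes={common.OUTPUT_TYPE: 19},
            crop_size=None,
            atrous_rates=None,
            output_stride=16)
        embeddings = get_embeddings(images, model_options, embedding_dimension)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_path)  # placeholder checkpoint path
            return sess.run(embeddings, feed_dict={images: image_batch})
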
def train(l_args):
    """Trains the model."""
    if l_args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    x_train_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/train/*/*.png'))
    x_label_files = sorted(
        glob.glob(
            '/datatmp/Datasets/Cityscapes/gtFine/train/*/*_labelIds.png'))
    y_train_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/train/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_train_files), len(y_train_files))
    assert (len(x_train_files) == len(y_train_files))
    assert (x_train_files[0].split("/")[-1] == y_train_files[0].split("/")[-1])
    assert (
        x_train_files[-1].split("/")[-1] == y_train_files[-1].split("/")[-1])

    # [:-16] strips the "_leftImg8bit.png" suffix, leaving the stem shared with
    # the corresponding gtFine label file.
    print(x_train_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_train_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert (len(x_label_files) == len(x_train_files))
    assert (x_train_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])
    assert (x_train_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train_files, x_label_files, y_train_files))
    train_dataset = train_dataset.shuffle(
        buffer_size=len(x_train_files)).repeat()
    train_dataset = train_dataset.map(
        read_pngs, num_parallel_calls=l_args.preprocess_threads)
    # Each element stacks 7 channels: original RGB (3) + labelIds (1) + compressed RGB (3).
    train_dataset = train_dataset.map(lambda x: tf.random_crop(
        x, [int(z) for z in l_args.patchsize.split(",")] + [7]))
    train_dataset = train_dataset.batch(l_args.batchsize)
    train_dataset = train_dataset.prefetch(l_args.batchsize)
    train_batch = train_dataset.make_one_shot_iterator().get_next()

    # Split the 7-channel batch: original RGB, labelIds (unused here), compressed RGB.
    train_x, _, train_y = (train_batch[:, :, :, :3],
                           train_batch[:, :, :, 3:4],
                           train_batch[:, :, :, 4:])
    scaled_train_x, scaled_train_y = train_x / 255., train_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)

    x_features, _ = model.extract_features(train_x, model_options)
    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)

    print(variables_to_restore)

    layers = RDN()
    scaled_x_tilde_hat = layers(scaled_train_y)
    x_tilde_hat = 255.0 * scaled_x_tilde_hat

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        x_tilde_hat_features, _ = model.extract_features(
            x_tilde_hat, model_options)

    mse = tf.reduce_mean(
        tf.squared_difference(scaled_train_x, scaled_x_tilde_hat)) * 255**2
    ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_tilde_hat, scaled_train_x, 1))
    l1 = tf.reduce_mean(tf.math.abs(scaled_train_x - scaled_x_tilde_hat))

    distortion = {"mse": mse, "l1": l1, "msssim": ssim}[l_args.loss_type]
    distillation = tf.reduce_mean(
        tf.squared_difference(x_features, x_tilde_hat_features))

    train_loss = distortion + l_args.mu * distillation

    var_list = [
        var for var in tf.trainable_variables()
        if var not in variables_to_restore
    ]

    print()
    print()
    print(var_list)

    step = tf.train.get_or_create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=l_args.lr)
    train_op = main_optimizer.minimize(train_loss,
                                       var_list=var_list,
                                       global_step=step)

    log_all_summaries(train_x, scaled_x_tilde_hat, scaled_train_y, None, None,
                      train_loss, None, mse, None, ssim, "train")
    #log_all_summaries(val_x, valid_x_tilde, None, None, valid_loss, None, valid_mse, None, valid_ssim, "val")

    def load_pretrain(scaffold, sess):
        seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)

    hooks = [
        tf.train.StopAtStepHook(last_step=l_args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=l_args.checkpoint_dir,
            save_checkpoint_secs=1200,
            save_summaries_secs=60,
            scaffold=tf.train.Scaffold(init_fn=load_pretrain)) as sess:
        while not sess.should_stop():
            sess.run(train_op)
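
For reference, a sketch of the command-line interface this train() function expects on l_args. The flag names are exactly the attributes accessed above; the types, defaults and choices are assumptions.

import argparse

def parse_train_args():
    """Hypothetical argument parser matching the l_args attributes used above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--lmbda', type=float, required=True)
    parser.add_argument('--preprocess_threads', type=int, default=16)
    parser.add_argument('--patchsize', type=str, default='256,256')  # "height,width"
    parser.add_argument('--batchsize', type=int, default=8)
    parser.add_argument('--loss_type', type=str, default='mse',
                        choices=['mse', 'l1', 'msssim'])
    parser.add_argument('--mu', type=float, default=1.0)  # distillation weight
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--checkpoint_dir', type=str, default='train_dir')
    parser.add_argument('--last_step', type=int, default=1000000)
    return parser.parse_args()
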
def get_logits_with_matching(images,
                             model_options,
                             weight_decay=0.0001,
                             reuse=None,
                             is_training=False,
                             fine_tune_batch_norm=False,
                             reference_labels=None,
                             batch_size=None,
                             num_frames_per_video=None,
                             embedding_dimension=None,
                             max_neighbors_per_object=0,
                             k_nearest_neighbors=1,
                             use_softmax_feedback=True,
                             initial_softmax_feedback=None,
                             embedding_seg_feature_dimension=256,
                             embedding_seg_n_layers=4,
                             embedding_seg_kernel_size=7,
                             embedding_seg_atrous_rates=None,
                             normalize_nearest_neighbor_distances=True,
                             also_attend_to_previous_frame=True,
                             damage_initial_previous_frame_mask=False,
                             use_local_previous_frame_attention=True,
                             previous_frame_attention_window_size=15,
                             use_first_frame_matching=True,
                             also_return_embeddings=False,
                             ref_embeddings=None):
  """Gets the logits by atrous/image spatial pyramid pooling using attention.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    reference_labels: The segmentation labels of the reference frame on which
      attention is applied.
    batch_size: Integer, the number of videos in a batch.
    num_frames_per_video: Integer, the number of frames per video.
    embedding_dimension: Integer, the dimension of the embedding.
    max_neighbors_per_object: Integer, the maximum number of candidates
      for the nearest neighbor query per object after subsampling.
      Can be 0 for no subsampling.
    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Only has an effect if use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    damage_initial_previous_frame_mask: Boolean, whether to artificially damage
      the initial previous frame mask. Only has an effect if
      also_attend_to_previous_frame is True.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.
  Returns:
    outputs_to_logits: A map from output_type to logits.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  with tf.variable_scope('embedding', reuse=reuse):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
    embeddings = tf.identity(embeddings, name='embeddings')
  scaled_reference_labels = tf.image.resize_nearest_neighbor(
      reference_labels,
      resolve_shape(embeddings, 4)[1:3],
      align_corners=True)
  h, w = decoder_height, decoder_width
  if num_frames_per_video is None:
    num_frames_per_video = tf.size(embeddings) // (
        batch_size * h * w * embedding_dimension)
  new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
  reshaped_reference_labels = tf.reshape(scaled_reference_labels,
                                         new_labels_shape)
  new_embeddings_shape = tf.stack([batch_size,
                                   num_frames_per_video, h, w,
                                   embedding_dimension])
  reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
  all_nn_features = []
  all_ref_obj_ids = []
  # To keep things simple, we do all this separate for each sequence for now.
  for n in range(batch_size):
    embedding = reshaped_embeddings[n]
    if ref_embeddings is None:
      n_chunks = 100
      reference_embedding = embedding[0]
      if also_attend_to_previous_frame or use_softmax_feedback:
        queries_embedding = embedding[2:]
      else:
        queries_embedding = embedding[1:]
    else:
      if USE_CORRELATION_COST:
        n_chunks = 20
      else:
        n_chunks = 500
      reference_embedding = ref_embeddings[0][n]
      queries_embedding = embedding
    reference_labels = reshaped_reference_labels[n][0]
    nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
        reference_embedding, queries_embedding, reference_labels,
        max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
    if normalize_nearest_neighbor_distances:
      nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
    all_nn_features.append(nn_features_n)
    all_ref_obj_ids.append(ref_obj_ids)

  feat_dim = resolve_shape(features)[-1]
  features = tf.reshape(features, tf.stack(
      [batch_size, num_frames_per_video, h, w, feat_dim]))
  if ref_embeddings is None:
    # Strip the features for the reference frame.
    if also_attend_to_previous_frame or use_softmax_feedback:
      features = features[:, 2:]
    else:
      features = features[:, 1:]

  # To keep things simple, we do all this separate for each sequence for now.
  outputs_to_logits = {output: [] for
                       output in model_options.outputs_to_num_classes}
  for n in range(batch_size):
    features_n = features[n]
    nn_features_n = all_nn_features[n]
    nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
    n_objs = tf.shape(nn_features_n_tr)[0]
    # Repeat features for every object.
    features_n_tiled = tf.tile(features_n[tf.newaxis],
                               multiples=[n_objs, 1, 1, 1, 1])
    prev_frame_labels = None
    if also_attend_to_previous_frame:
      prev_frame_labels = reshaped_reference_labels[n, 1]
      if is_training and damage_initial_previous_frame_mask:
        # Damage the previous frame masks.
        prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
                                                       dilate=False)
      tf.summary.image('prev_frame_labels',
                       tf.cast(prev_frame_labels[tf.newaxis],
                               tf.uint8) * 32)
      initial_softmax_feedback_n = create_initial_softmax_from_labels(
          prev_frame_labels, reshaped_reference_labels[n][0],
          decoder_output_stride=None, reduce_labels=True)
    elif initial_softmax_feedback is not None:
      initial_softmax_feedback_n = initial_softmax_feedback[n]
    else:
      initial_softmax_feedback_n = None
    if initial_softmax_feedback_n is None:
      last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
    else:
      last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
          ..., tf.newaxis]
    assert len(model_options.outputs_to_num_classes) == 1
    output = list(model_options.outputs_to_num_classes.keys())[0]
    logits = []
    n_ref_frames = 1
    prev_frame_nn_features_n = None
    if also_attend_to_previous_frame or use_softmax_feedback:
      n_ref_frames += 1
    if ref_embeddings is not None:
      n_ref_frames = 0
    for t in range(num_frames_per_video - n_ref_frames):
      to_concat = [features_n_tiled[:, t]]
      if use_first_frame_matching:
        to_concat.append(nn_features_n_tr[:, t])
      if use_softmax_feedback:
        to_concat.append(last_softmax)
      if also_attend_to_previous_frame:
        assert normalize_nearest_neighbor_distances, (
            'previous frame attention currently only works when normalized '
            'distances are used')
        embedding = reshaped_embeddings[n]
        if ref_embeddings is None:
          last_frame_embedding = embedding[t + 1]
          query_embeddings = embedding[t + 2, tf.newaxis]
        else:
          last_frame_embedding = ref_embeddings[1][0]
          query_embeddings = embedding
        if use_local_previous_frame_attention:
          assert query_embeddings.shape[0] == 1
          prev_frame_nn_features_n = (
              local_previous_frame_nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings[0],
                  prev_frame_labels,
                  all_ref_obj_ids[n],
                  max_distance=previous_frame_attention_window_size)
          )
        else:
          prev_frame_nn_features_n, _ = (
              nearest_neighbor_features_per_object(
                  last_frame_embedding, query_embeddings, prev_frame_labels,
                  max_neighbors_per_object, k_nearest_neighbors,
                  gt_ids=all_ref_obj_ids[n]))
          prev_frame_nn_features_n = (tf.nn.sigmoid(
              prev_frame_nn_features_n) - 0.5) * 2
        prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
                                                 axis=0)
        prev_frame_nn_features_n_tr = tf.transpose(
            prev_frame_nn_features_n_sq, [2, 0, 1, 3])
        to_concat.append(prev_frame_nn_features_n_tr)
      features_n_concat_t = tf.concat(to_concat, axis=-1)
      embedding_seg_features_n_t = (
          create_embedding_segmentation_features(
              features_n_concat_t, embedding_seg_feature_dimension,
              embedding_seg_n_layers, embedding_seg_kernel_size,
              reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
      logits_t = model.get_branch_logits(
          embedding_seg_features_n_t,
          1,
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse or n > 0 or t > 0,
          scope_suffix=output
      )
      logits.append(logits_t)
      prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
                                       [2, 0, 1])
      last_softmax = tf.nn.softmax(logits_t, axis=0)
    logits = tf.stack(logits, axis=1)
    logits_shape = tf.stack(
        [n_objs, num_frames_per_video - n_ref_frames] +
        resolve_shape(logits)[2:-1])
    logits_reshaped = tf.reshape(logits, logits_shape)
    logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
    outputs_to_logits[output].append(logits_transposed)

    add_image_summaries(
        images[n * num_frames_per_video: (n+1) * num_frames_per_video],
        nn_features_n,
        logits_transposed,
        batch_size=1,
        prev_frame_nn_features=prev_frame_nn_features_n)
  if also_return_embeddings:
    return outputs_to_logits, embeddings
  else:
    return outputs_to_logits
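
A small self-contained check of the distance normalization used above: for non-negative nearest-neighbor distances d, (sigmoid(d) - 0.5) * 2 maps zero distance to 0 and large distances toward 1, which is the [0, 1] normalization described in the docstring.

import numpy as np

def normalize_nn_distances(d):
    """Mirrors (tf.nn.sigmoid(d) - 0.5) * 2 from the code above, in NumPy."""
    return (1.0 / (1.0 + np.exp(-d)) - 0.5) * 2.0

print(normalize_nn_distances(np.array([0.0, 0.5, 2.0, 10.0])))
# -> approximately [0.000, 0.245, 0.762, 1.000]
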
def get_logits_with_matching(images,
                             model_options,
                             weight_decay=0.0001,
                             reuse=None,
                             is_training=False,
                             fine_tune_batch_norm=False,
                             reference_labels=None,
                             batch_size=None,
                             num_frames_per_video=None,
                             embedding_dimension=None,
                             max_neighbors_per_object=0,
                             k_nearest_neighbors=1,
                             use_softmax_feedback=True,
                             initial_softmax_feedback=None,
                             embedding_seg_feature_dimension=256,
                             embedding_seg_n_layers=4,
                             embedding_seg_kernel_size=7,
                             embedding_seg_atrous_rates=None,
                             normalize_nearest_neighbor_distances=True,
                             also_attend_to_previous_frame=True,
                             damage_initial_previous_frame_mask=False,
                             use_local_previous_frame_attention=True,
                             previous_frame_attention_window_size=15,
                             use_first_frame_matching=True,
                             also_return_embeddings=False,
                             ref_embeddings=None):
  """Gets the logits by atrous/image spatial pyramid pooling using attention.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    reference_labels: The segmentation labels of the reference frame on which
      attention is applied.
    batch_size: Integer, the number of videos in a batch.
    num_frames_per_video: Integer, the number of frames per video.
    embedding_dimension: Integer, the dimension of the embedding.
    max_neighbors_per_object: Integer, the maximum number of candidates
      for the nearest neighbor query per object after subsampling.
      Can be 0 for no subsampling.
    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
      initialize the softmax predictions used for the feedback loop.
      Only has an effect if use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation head.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0,1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    damage_initial_previous_frame_mask: Boolean, whether to artificially damage
      the initial previous frame mask. Only has an effect if
      also_attend_to_previous_frame is True.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window.
      Only has an effect if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of
      (first_frame_embeddings, previous_frame_embeddings),
      each of shape [batch, height, width, embedding_dimension], or None.
  Returns:
    outputs_to_logits: A map from output_type to logits.
    If also_return_embeddings is True, it will also return an embeddings
      tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(height, 1.0 / decoder_output_stride)
    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  with tf.variable_scope('embedding', reuse=reuse):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
    embeddings = tf.identity(embeddings, name='embeddings')
  scaled_reference_labels = tf.image.resize_nearest_neighbor(
      reference_labels,
      resolve_shape(embeddings, 4)[1:3],
      align_corners=True)
  h, w = decoder_height, decoder_width
  if num_frames_per_video is None:
    num_frames_per_video = tf.size(embeddings) // (
        batch_size * h * w * embedding_dimension)
  new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
  reshaped_reference_labels = tf.reshape(scaled_reference_labels,
                                         new_labels_shape)
  new_embeddings_shape = tf.stack([batch_size,
                                   num_frames_per_video, h, w,
                                   embedding_dimension])
  reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
  all_nn_features = []
  all_ref_obj_ids = []
  # To keep things simple, we do all this separate for each sequence for now.
  for n in range(batch_size):
    embedding = reshaped_embeddings[n]
    if ref_embeddings is None:
      n_chunks = 100
      reference_embedding = embedding[0]
      if also_attend_to_previous_frame or use_softmax_feedback:
        queries_embedding = embedding[2:]
      else:
        queries_embedding = embedding[1:]
    else:
      if USE_CORRELATION_COST:
        n_chunks = 20
      else:
        n_chunks = 500
      reference_embedding = ref_embeddings[0][n]
      queries_embedding = embedding
    reference_labels = reshaped_reference_labels[n][0]
    nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
        reference_embedding, queries_embedding, reference_labels,
        max_neighbors_per_object, k_nearest_neighbors, n_chunks=n_chunks)
    if normalize_nearest_neighbor_distances:
      nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
    all_nn_features.append(nn_features_n)
    all_ref_obj_ids.append(ref_obj_ids)

  feat_dim = resolve_shape(features)[-1]
  features = tf.reshape(features, tf.stack(
      [batch_size, num_frames_per_video, h, w, feat_dim]))
  if ref_embeddings is None:
    # Strip the features for the reference frame.
    if also_attend_to_previous_frame or use_softmax_feedback:
      features = features[:, 2:]
    else:
      features = features[:, 1:]

  # To keep things simple, we do all this separate for each sequence for now.
  outputs_to_logits = {output: [] for
                       output in model_options.outputs_to_num_classes}
  for n in range(batch_size):
    features_n = features[n]
    nn_features_n = all_nn_features[n]
    nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
    n_objs = tf.shape(nn_features_n_tr)[0]
    # Repeat features for every object.
    features_n_tiled = tf.tile(features_n[tf.newaxis],
                               multiples=[n_objs, 1, 1, 1, 1])
    prev_frame_labels = None
    if also_attend_to_previous_frame:
      prev_frame_labels = reshaped_reference_labels[n, 1]
      if is_training and damage_initial_previous_frame_mask:
        # Damage the previous frame masks.
        prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
                                                       dilate=False)
      tf.summary.image('prev_frame_labels',
                       tf.cast(prev_frame_labels[tf.newaxis],
                               tf.uint8) * 32)
      initial_softmax_feedback_n = create_initial_softmax_from_labels(
          prev_frame_labels, reshaped_reference_labels[n][0],
          decoder_output_stride=None, reduce_labels=True)
    elif initial_softmax_feedback is not None:
      initial_softmax_feedback_n = initial_softmax_feedback[n]
    else:
      initial_softmax_feedback_n = None
    if initial_softmax_feedback_n is None:
      last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
    else:
      last_softmax = tf.transpose(initial_softmax_feedback_n, [2, 0, 1])[
          ..., tf.newaxis]
    assert len(model_options.outputs_to_num_classes) == 1
    output = list(model_options.outputs_to_num_classes.keys())[0]
    logits = []
    n_ref_frames = 1
    prev_frame_nn_features_n = None
    if also_attend_to_previous_frame or use_softmax_feedback:
      n_ref_frames += 1
    if ref_embeddings is not None:
      n_ref_frames = 0
    for t in range(num_frames_per_video - n_ref_frames):
      to_concat = [features_n_tiled[:, t]]
      if use_first_frame_matching:
        to_concat.append(nn_features_n_tr[:, t])
      if use_softmax_feedback:
        to_concat.append(last_softmax)
      if also_attend_to_previous_frame:
        assert normalize_nearest_neighbor_distances, (
            'previous frame attention currently only works when normalized '
            'distances are used')
        embedding = reshaped_embeddings[n]
        if ref_embeddings is None:
          last_frame_embedding = embedding[t + 1]
          query_embeddings = embedding[t + 2, tf.newaxis]
        else:
          last_frame_embedding = ref_embeddings[1][0]
          query_embeddings = embedding
        if use_local_previous_frame_attention:
          assert query_embeddings.shape[0] == 1
          prev_frame_nn_features_n = (
              local_previous_frame_nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings[0],
                  prev_frame_labels,
                  all_ref_obj_ids[n],
                  max_distance=previous_frame_attention_window_size)
          )
        else:
          prev_frame_nn_features_n, _ = (
              nearest_neighbor_features_per_object(
                  last_frame_embedding, query_embeddings, prev_frame_labels,
                  max_neighbors_per_object, k_nearest_neighbors,
                  gt_ids=all_ref_obj_ids[n]))
          prev_frame_nn_features_n = (tf.nn.sigmoid(
              prev_frame_nn_features_n) - 0.5) * 2
        prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
                                                 axis=0)
        prev_frame_nn_features_n_tr = tf.transpose(
            prev_frame_nn_features_n_sq, [2, 0, 1, 3])
        to_concat.append(prev_frame_nn_features_n_tr)
      features_n_concat_t = tf.concat(to_concat, axis=-1)
      embedding_seg_features_n_t = (
          create_embedding_segmentation_features(
              features_n_concat_t, embedding_seg_feature_dimension,
              embedding_seg_n_layers, embedding_seg_kernel_size,
              reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
      logits_t = model.get_branch_logits(
          embedding_seg_features_n_t,
          1,
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse or n > 0 or t > 0,
          scope_suffix=output
      )
      logits.append(logits_t)
      prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0),
                                       [2, 0, 1])
      last_softmax = tf.nn.softmax(logits_t, axis=0)
    logits = tf.stack(logits, axis=1)
    logits_shape = tf.stack(
        [n_objs, num_frames_per_video - n_ref_frames] +
        resolve_shape(logits)[2:-1])
    logits_reshaped = tf.reshape(logits, logits_shape)
    logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
    outputs_to_logits[output].append(logits_transposed)

    add_image_summaries(
        images[n * num_frames_per_video: (n+1) * num_frames_per_video],
        nn_features_n,
        logits_transposed,
        batch_size=1,
        prev_frame_nn_features=prev_frame_nn_features_n)
  if also_return_embeddings:
    return outputs_to_logits, embeddings
  else:
    return outputs_to_logits
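
A toy check of the reshape bookkeeping in this function: the backbone flattens video frames into the batch dimension, so embeddings of shape [batch * num_frames_per_video, h, w, dim] are regrouped per video before the per-sequence matching. The sizes below are illustrative only.

import numpy as np

batch_size, num_frames_per_video, h, w, dim = 2, 3, 4, 5, 8
flat = np.arange(batch_size * num_frames_per_video * h * w * dim, dtype=np.float32)
embeddings = flat.reshape([batch_size * num_frames_per_video, h, w, dim])
reshaped_embeddings = embeddings.reshape(
    [batch_size, num_frames_per_video, h, w, dim])
# Frame 0 of the second video is frame index num_frames_per_video in the flat layout.
assert np.array_equal(reshaped_embeddings[1, 0], embeddings[num_frames_per_video])
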
def train(l_args):
    """Trains the model."""
    if l_args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    x_train_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/train/*/*.png'))
    x_label_files = sorted(
        glob.glob(
            '/datatmp/Datasets/Cityscapes/gtFine/train/*/*_labelIds.png'))
    y_train_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/train/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_train_files), len(y_train_files))
    assert (len(x_train_files) == len(y_train_files))
    assert (x_train_files[0].split("/")[-1] == y_train_files[0].split("/")[-1])
    assert (
        x_train_files[-1].split("/")[-1] == y_train_files[-1].split("/")[-1])

    print(x_train_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_train_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert (len(x_label_files) == len(x_train_files))
    assert (x_train_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])
    assert (x_train_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train_files, x_label_files, y_train_files))
    train_dataset = train_dataset.shuffle(
        buffer_size=len(x_train_files)).repeat()
    train_dataset = train_dataset.map(read_pngs,
                                      num_parallel_calls=min(
                                          l_args.preprocess_threads,
                                          l_args.batchsize))
    if l_args.resize_images:
        train_dataset = train_dataset.map(
            lambda x: tf.image.resize_images(x, [512, 1024]))
    # Each element stacks 7 channels: original RGB (3) + labelIds (1) + compressed RGB (3).
    train_dataset = train_dataset.map(lambda x: tf.random_crop(
        x, [int(z) for z in l_args.patchsize.split(",")] + [7]))
    train_dataset = train_dataset.batch(l_args.batchsize)
    train_dataset = train_dataset.prefetch(l_args.batchsize)
    train_batch = train_dataset.make_one_shot_iterator().get_next()

    # Split the 7-channel batch: original RGB, labelIds (unused here), compressed RGB.
    train_x, _, train_y = (train_batch[:, :, :, :3],
                           train_batch[:, :, :, 3:4],
                           train_batch[:, :, :, 4:])
    scaled_train_x, scaled_train_y = train_x / 255., train_y / 255.

    x_val_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/val/*/*.png'))
    x_label_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/gtFine/val/*/*_labelIds.png'))
    y_val_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/val/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_val_files), len(y_val_files))
    assert (len(x_val_files) == len(y_val_files))
    assert (x_val_files[0].split("/")[-1] == y_val_files[0].split("/")[-1])
    assert (x_val_files[-1].split("/")[-1] == y_val_files[-1].split("/")[-1])

    print(x_val_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_val_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert (len(x_label_files) == len(x_val_files))
    assert (x_val_files[0].split("/")[-1][:-16] == x_label_files[0].split("/")
            [-1].split("_gtFine_labelIds.png")[0])
    assert (x_val_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])

    def set_shape(x):
        x.set_shape([1024, 2048, 7])
        return x

    val_dataset = tf.data.Dataset.from_tensor_slices(
        (x_val_files, x_label_files, y_val_files))
    val_dataset = val_dataset.map(read_pngs, num_parallel_calls=1)
    val_dataset = val_dataset.map(set_shape, num_parallel_calls=1)
    val_dataset = val_dataset.batch(1)
    val_dataset = val_dataset.prefetch(1)
    val_batch = val_dataset.make_one_shot_iterator().get_next()

    val_x, _, val_y = (val_batch[:, :, :, :3],
                       val_batch[:, :, :, 3:4],
                       val_batch[:, :, :, 4:])
    scaled_val_x, scaled_val_y = val_x / 255., val_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)

    x_features, _ = model.extract_features(train_x, model_options)
    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)

    print(variables_to_restore)

    rdn = RDN()
    scaled_x_tilde_hat = rdn(scaled_train_y)
    x_tilde_hat = 255.0 * scaled_x_tilde_hat

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        x_tilde_hat_features, _ = model.extract_features(
            x_tilde_hat, model_options)

    var_list = [
        var for var in tf.trainable_variables()
        if var not in variables_to_restore
    ]

    print()
    print()
    print(var_list)

    discriminator = PatchDiscriminator(l_args.disc_patchsize)
    fake = tf.reduce_mean(discriminator(scaled_x_tilde_hat, scaled_train_y))
    real = tf.reduce_mean(discriminator(scaled_train_x, scaled_train_y))

    generator_loss = -1.0 * fake
    wasserstein_distance = real - fake
    discriminator_loss = -1.0 * wasserstein_distance

    mse = tf.reduce_mean(
        tf.squared_difference(scaled_train_x, scaled_x_tilde_hat)) * 255**2
    ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_tilde_hat, scaled_train_x, 1))
    l1 = tf.reduce_mean(tf.math.abs(scaled_train_x - scaled_x_tilde_hat))

    distortion = {
        "mse": mse,
        "l1": l1,
        "msssim": ssim,
        "msssim_l1": 2 * l1 + ssim
    }[l_args.loss_type]
    distillation = tf.reduce_mean(
        tf.squared_difference(x_features, x_tilde_hat_features))

    train_loss = l_args.rho * distortion + l_args.mu * distillation + generator_loss

    rdn_weights = var_list
    discriminator_weights = discriminator.weights

    print()
    print()
    print(discriminator_weights)

    print()
    print()
    print([
        var for var in tf.trainable_variables()
        if var not in variables_to_restore + rdn_weights +
        discriminator_weights
    ])

    step = tf.train.get_or_create_global_step()
    generator_optimizer = tf.train.AdamOptimizer(learning_rate=l_args.lr,
                                                 beta1=0,
                                                 beta2=0.9)
    generator_op = generator_optimizer.minimize(train_loss,
                                                var_list=rdn_weights,
                                                global_step=step)

    discriminator_optimizer = tf.train.AdamOptimizer(
        learning_rate=l_args.disc_lr, beta1=0, beta2=0.9)
    discriminator_op = discriminator_optimizer.minimize(
        discriminator_loss, var_list=discriminator_weights)

    train_summary = log_all_summaries(train_x, scaled_x_tilde_hat,
                                      scaled_train_y, None, None, train_loss,
                                      None, mse, None, ssim, distillation,
                                      wasserstein_distance, l1, "train")

    scaled_x_val_hat = rdn(scaled_val_y)
    val_fake = tf.reduce_mean(discriminator(scaled_x_val_hat, scaled_val_y))
    val_real = tf.reduce_mean(discriminator(scaled_val_x, scaled_val_y))

    val_generator_loss = -1.0 * val_fake
    val_wasserstein = val_real - val_fake

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        val_x_features, _ = model.extract_features(val_x, model_options)
        x_val_hat_features, _ = model.extract_features(
            255.0 * scaled_x_val_hat, model_options)

    val_distillation = tf.reduce_mean(
        tf.squared_difference(val_x_features, x_val_hat_features))
    val_mse = tf.reduce_mean(
        tf.squared_difference(scaled_val_x, scaled_x_val_hat)) * 255**2
    val_ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_val_hat, scaled_val_x, 1))
    val_l1 = tf.reduce_mean(tf.math.abs(scaled_val_x - scaled_x_val_hat))
    val_distortion = {
        "mse": val_mse,
        "l1": val_l1,
        "msssim": val_ssim,
        "msssim_l1": 2 * val_l1 + val_ssim
    }[l_args.loss_type]
    val_loss = l_args.rho * val_distortion + l_args.mu * val_distillation + val_generator_loss

    #valid_summary = log_all_summaries(val_x, scaled_x_val_hat, scaled_val_y,
    #                 None, None, val_loss, None, val_mse, None, val_ssim, val_distillation, val_wasserstein, val_l1, "val")

    def load_pretrain(scaffold, sess):
        seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)

    hooks = [
        tf.train.StopAtStepHook(last_step=l_args.last_step),
        tf.train.NanTensorHook(train_loss),
        #tf.train.SummarySaverHook(save_secs=120, output_dir=l_args.checkpoint_dir,summary_op=valid_summary),
        tf.train.SummarySaverHook(save_secs=60,
                                  output_dir=l_args.checkpoint_dir,
                                  summary_op=train_summary),
    ]

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=l_args.checkpoint_dir,
            save_checkpoint_secs=1200,
            save_summaries_steps=None,
            save_summaries_secs=None,
            scaffold=tf.train.Scaffold(init_fn=load_pretrain)) as sess:
        while not sess.should_stop():
            sess.run(discriminator_op)
            sess.run(generator_op)
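
The adversarial terms above follow a WGAN-style objective: the PatchDiscriminator acts as a critic that maximizes the score gap between real and reconstructed patches, while the RDN generator is pushed to raise the fake score on top of the distortion and distillation terms. A framework-free restatement with stand-in scalar scores:

def wgan_losses(real_score, fake_score):
    """Mirrors the generator/discriminator losses defined in train() above."""
    wasserstein_distance = real_score - fake_score
    discriminator_loss = -wasserstein_distance  # critic maximizes the distance
    generator_loss = -fake_score                # generator maximizes the fake score
    return generator_loss, discriminator_loss

gen_loss, disc_loss = wgan_losses(real_score=0.7, fake_score=0.2)
# The full generator objective is then:
#   l_args.rho * distortion + l_args.mu * distillation + gen_loss
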
def experiment(l_args):
    x_val_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/val/*/*.png'))
    x_label_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/gtFine/val/*/*_labelIds.png'))
    y_val_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/val/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_val_files), len(y_val_files))
    assert (len(x_val_files) == len(y_val_files))
    assert (x_val_files[0].split("/")[-1] == y_val_files[0].split("/")[-1])
    assert (x_val_files[-1].split("/")[-1] == y_val_files[-1].split("/")[-1])

    print(x_val_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_val_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert (len(x_label_files) == len(x_val_files))
    assert (x_val_files[0].split("/")[-1][:-16] == x_label_files[0].split("/")
            [-1].split("_gtFine_labelIds.png")[0])
    assert (x_val_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0])

    def set_shape(x):
        x.set_shape([1024, 2048, 7])
        return x

    val_dataset = tf.data.Dataset.from_tensor_slices(
        (x_val_files, x_label_files, y_val_files))
    val_dataset = val_dataset.map(read_pngs,
                                  num_parallel_calls=l_args.preprocess_threads)
    val_dataset = val_dataset.map(set_shape,
                                  num_parallel_calls=l_args.preprocess_threads)
    val_dataset = val_dataset.batch(1)
    val_dataset = val_dataset.prefetch(1)
    val_batch = val_dataset.make_one_shot_iterator().get_next()

    val_x, _, val_y = (val_batch[:, :, :, :3],
                       val_batch[:, :, :, 3:4],
                       val_batch[:, :, :, 4:])
    scaled_val_x, scaled_val_y = val_x / 255., val_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)

    x_features, _ = model.extract_features(val_x, model_options)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        y_features, _ = model.extract_features(val_y, model_options)

    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)

    diff_features = x_features - y_features

    sess = tf.Session()
    seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)

    #while not sess.should_stop():
    feature_index = 256
    for i in range(500):
        x, y, fx, fy, dxy = sess.run(
            [val_x, val_y, x_features, y_features, diff_features])
        print(i)
        rxy = np.maximum(fx, 1e-6) / np.maximum(fy, 1e-6)
        #for i in range(fx.shape[-1]):
        fig = plt.figure(figsize=(30, 20))
        ax1 = fig.add_subplot(3, 2, 1)
        ax1.imshow(x[0].astype(np.uint8))
        ax2 = fig.add_subplot(3, 2, 2)
        ax2.imshow(y[0].astype(np.uint8))
        ax3 = fig.add_subplot(3, 2, 3)
        ax3.imshow(fx[0, :, :, feature_index] /
                   np.max(fx[0, :, :, feature_index]))
        ax4 = fig.add_subplot(3, 2, 4)
        ax4.imshow(fy[0, :, :, feature_index] /
                   np.max(fy[0, :, :, feature_index]))
        ax5 = fig.add_subplot(3, 2, 5)
        ax5.imshow(
            np.abs(dxy[0, :, :, feature_index]) /
            np.max(np.abs(dxy[0, :, :, feature_index])))
        ax6 = fig.add_subplot(3, 2, 6)
        ax6.imshow(rxy[0, :, :, feature_index] /
                   np.max(rxy[0, :, :, feature_index]))
        plt.savefig("feats/{}.png".format(i))