def main():
  """Create the model and start the Inference process.
  """
  args = get_arguments()
    
  # Parse image processing arguments.
  input_size = parse_commastr(args.input_size)
  strides = parse_commastr(args.strides)
  assert input_size is not None and strides is not None
  h, w = input_size
  innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))


  # Create queue coordinator.
  coord = tf.train.Coordinator()

  # Load the data reader.
  with tf.name_scope('create_inputs'):
    reader = VMFImageReader(
        args.data_dir,
        args.data_list,
        None,
        False, # No random scale.
        False, # No random mirror.
        False, # No random crop, center crop instead
        args.ignore_label,
        IMG_MEAN)

    image = reader.image
    label = reader.label
    image_list = reader.image_list
  image_batch = tf.expand_dims(image, axis=0)
  label_batch = tf.expand_dims(label, axis=0)

  # Create input tensor to the network.
  crop_image_batch = tf.placeholder(
      name='crop_image_batch',
      shape=[1,input_size[0],input_size[1],3],
      dtype=tf.float32)

  # Create network and output prediction.
  outputs = model(crop_image_batch,
                  args.embedding_dim,
                  False,
                  True)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables() if 'crop_image_batch' not in v.name]
    
  # Output predictions.
  output = outputs[0]
  output = tf.image.resize_bilinear(
      output,
      [input_size[0], input_size[1]])

  # Placeholders for the full-sized label map, embedding, and auxiliary features.
  label_input = tf.placeholder(
      tf.int32, shape=[1, None, None, 1])
  embedding_input = tf.placeholder(
      tf.float32, shape=[1, None, None, args.embedding_dim])
  embedding = common_utils.normalize_embedding(embedding_input)
  loc_feature = tf.placeholder(
      tf.float32, shape=[1, None, None, 2])
  rgb_feature = tf.placeholder(
      tf.float32, shape=[1, None, None, 3])
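  # Two-stage pipeline: the network above embeds fixed-size crops fed through
  # 'crop_image_batch'; the stitched full-image embedding is then fed back
  # through these placeholders for clustering and prototype extraction.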

  # Combine embedding with location and RGB features for kmeans clustering.
  shape = tf.shape(embedding)
  cluster_labels = common_utils.initialize_cluster_labels(
      [args.num_clusters, args.num_clusters],
      [shape[1], shape[2]])
  embedding = tf.reshape(embedding, [-1, args.embedding_dim])
  labels = tf.reshape(label_input, [-1])
  cluster_labels = tf.reshape(cluster_labels, [-1])
  location_features = tf.reshape(loc_feature, [-1, 2])
  rgb_features = common_utils.normalize_embedding(
      tf.reshape(rgb_feature, [-1, 3])) / args.embedding_dim

  # Collect pixels of valid semantic classes.
  valid_pixels = tf.where(
      tf.not_equal(labels, args.ignore_label))
  labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1)
  cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
  location_features = tf.squeeze(
      tf.gather(location_features, valid_pixels), axis=1)
  rgb_features = tf.squeeze(tf.gather(rgb_features, valid_pixels), axis=1)

  # Generate cluster labels via kmeans clustering.
  embedding_with_location = tf.concat(
      [embedding, location_features, rgb_features], 1)
  embedding_with_location = common_utils.normalize_embedding(
      embedding_with_location)
  cluster_labels = common_utils.kmeans_with_initial_labels(
      embedding_with_location,
      cluster_labels,
      args.num_clusters * args.num_clusters,
      args.kmeans_iterations)
  _, cluster_labels = tf.unique(cluster_labels)

  # Find pixels of majority semantic classes.
  select_pixels, prototype_labels = eval_utils.find_majority_label_index(
      labels, cluster_labels)
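  # Each cluster takes its majority semantic label; pixels disagreeing with
  # their cluster's majority are dropped before averaging prototype features.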

  # Calculate the prototype features.
  cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)

  prototype_features = common_utils.calculate_prototypes_from_labels(
      embedding, cluster_labels)


  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
  sess.run(tf.local_variables_initializer())
    
  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    load(loader, sess, args.restore_from)
    
  # Start queue threads.
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  # Create directory for saving prototypes.
  save_dir = os.path.join(args.save_dir, 'prototypes')
  if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
    
  # Iterate over testing steps.
  with open(args.data_list, 'r') as listf:
    num_steps = len(listf.read().split('\n')) - 1


  pbar = tqdm(range(num_steps))
  for step in pbar:
    image_batch_np, label_batch_np = sess.run(
        [image_batch, label_batch])

    img_size = image_batch_np.shape
    padded_img_size = list(img_size)  # copy so padding can enlarge dims

    if input_size[0] > padded_img_size[1]:
      padded_img_size[1] = input_size[0]
    if input_size[1] > padded_img_size[2]:
      padded_img_size[2] = input_size[1]
    padded_img_batch = np.zeros(padded_img_size,
                                dtype=np.float32)
    img_h, img_w = img_size[1:3]
    padded_img_batch[:, :img_h, :img_w, :] = image_batch_np

    stride_h, stride_w = strides
    npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1
    npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1
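    # ceil((padded - input) / stride) + 1 windows ensure the last patch ends
    # exactly at the padded border.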

    # Create the ending index of each patch.
    patch_indh = np.linspace(
        input_size[0], padded_img_size[1], npatches_h, dtype=np.int32)
    patch_indw = np.linspace(
        input_size[1], padded_img_size[2], npatches_w, dtype=np.int32)
    
    # Create embedding holder.
    padded_img_size[-1] = args.embedding_dim
    embedding_all_np = np.zeros(padded_img_size,
                                dtype=np.float32)
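    # Run the network on every crop; overlapping regions accumulate by
    # summation, which the normalization of 'embedding_input' above turns
    # into an averaged embedding direction.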
    for indh in patch_indh:
      for indw in patch_indw:
        sh, eh = indh-input_size[0], indh  # start & end ind of H
        sw, ew = indw-input_size[1], indw  # start & end ind of W
        cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :]

        embedding_np = sess.run(output, feed_dict={
            crop_image_batch: cropimg_batch})
        embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np

    embedding_all_np = embedding_all_np[:, :img_h, :img_w, :]
    loc_feature_np = common_utils.generate_location_features_np(
        [padded_img_size[1], padded_img_size[2]])
    feed_dict = {label_input: label_batch_np,
                 embedding_input: embedding_all_np,
                 loc_feature: loc_feature_np,
                 rgb_feature: padded_img_batch}

    (batch_prototype_features_np,
     batch_prototype_labels_np) = sess.run(
      [prototype_features, prototype_labels],
      feed_dict=feed_dict)

    if step == 0:
      prototype_features_np = batch_prototype_features_np
      prototype_labels_np = batch_prototype_labels_np
    else:
      prototype_features_np = np.concatenate(
          [prototype_features_np, batch_prototype_features_np], axis=0)
      prototype_labels_np = np.concatenate(
          [prototype_labels_np,
           batch_prototype_labels_np], axis=0)


  print('Total number of prototypes extracted: ',
        len(prototype_labels_np))
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                    mode='wb'), prototype_features_np)
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                    mode='wb'), prototype_labels_np)


  coord.request_stop()
  coord.join(threads)
Example #2
def main():
  """Create the model and start the Inference process.
  """
  args = get_arguments()
    
  # Parse image processing arguments.
  input_size = parse_commastr(args.input_size)
  strides = parse_commastr(args.strides)
  assert input_size is not None and strides is not None
  h, w = input_size
  innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))


  # Create queue coordinator.
  coord = tf.train.Coordinator()

  # Load the data reader.
  with tf.name_scope('create_inputs'):
    reader = VMFImageReader(
        args.data_dir,
        args.data_list,
        None,
        False, # No random scale.
        False, # No random mirror.
        False, # No random crop, center crop instead
        args.ignore_label,
        IMG_MEAN)

    image_batch = tf.expand_dims(reader.image, axis=0)
    label_batch = tf.expand_dims(reader.label, axis=0)
    cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
    loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)
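    # Unlike the sliding-window variant above, this reader also provides
    # precomputed initial cluster labels and location features per image.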

  # Create network and output prediction.
  outputs = model(image_batch,
                  args.embedding_dim,
                  False,
                  True)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables() if 'crop_image_batch' not in v.name]
    
  # Output predictions.
  output = outputs[0]
  output = tf.image.resize_bilinear(
      output,
      tf.shape(image_batch)[1:3])
  embedding = common_utils.normalize_embedding(output)

  shape = embedding.get_shape().as_list()
  batch_size = shape[0]

  labels = label_batch
  initial_cluster_labels = cluster_label_batch[0, :, :]
  location_features = tf.reshape(loc_feature_batch[0, :, :], [-1, 2])

  prototype_feature_list = []
  prototype_label_list = []
  for bs in range(batch_size):
    cur_labels = tf.reshape(labels[bs], [-1])
    cur_cluster_labels = tf.reshape(initial_cluster_labels, [-1])
    cur_embedding = tf.reshape(embedding[bs], [-1, args.embedding_dim])

    (prototype_features,
     prototype_labels,
     _) = eval_utils.extract_trained_prototypes(
         cur_embedding, location_features, cur_cluster_labels,
         args.num_clusters * args.num_clusters,
         args.kmeans_iterations, cur_labels,
         1, args.ignore_label,
         'semantic')

    prototype_feature_list.append(prototype_features)
    prototype_label_list.append(prototype_labels)

  prototype_features = tf.concat(prototype_feature_list, axis=0)
  prototype_labels = tf.concat(prototype_label_list, axis=0)
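  # Prototypes from every image in the batch are concatenated so a single
  # sess.run yields one feature/label array pair per step.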

    
  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
  sess.run(tf.local_variables_initializer())
    
  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    load(loader, sess, args.restore_from)
    
  # Start queue threads.
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  # Create directory for saving prototypes.
  save_dir = os.path.join(args.save_dir, 'prototypes')
  if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
    
  # Iterate over testing steps.
  with open(args.data_list, 'r') as listf:
    num_steps = len(listf.read().split('\n')) - 1

  for step in range(num_steps):
    (batch_prototype_features_np,
     batch_prototype_labels_np) = sess.run(
      [prototype_features, prototype_labels])

    if step == 0:
      prototype_features_np = batch_prototype_features_np
      prototype_labels_np = batch_prototype_labels_np
    else:
      prototype_features_np = np.concatenate(
          [prototype_features_np, batch_prototype_features_np], axis=0)
      prototype_labels_np = np.concatenate(
          [prototype_labels_np,
           batch_prototype_labels_np], axis=0)

    if (step + 1) % 100 == 0:
      print('Processed batches: ', (step + 1), '/', num_steps)

  print('Total number of prototypes extracted: ',
        len(prototype_labels_np))
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                    mode='wb'), prototype_features_np)
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                    mode='wb'), prototype_labels_np)


  coord.request_stop()
  coord.join(threads)
Example #3
def main():
    """Create the model and start the Inference process.
  """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, axis=0)
        label_batch = tf.expand_dims(reader.label, axis=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features.
    embedding_with_location = tf.concat([embedding, loc_feature_batch], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Kmeans clustering.
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
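    # Test-time prototypes are matched to stored training prototypes via
    # k-nearest neighbors, and the winning label is propagated to every
    # pixel of the corresponding cluster.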
    semantic_predictions, _ = eval_utils.predict_semantic_instance_labels(
        cluster_labels, test_prototypes, prototype_features, prototype_labels,
        None, args.k_in_nearest_neighbors)
    semantic_predictions = tf.cast(semantic_predictions, tf.uint8)
    semantic_predictions = tf.squeeze(semantic_predictions)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    for step in range(num_steps):
        semantic_predictions_np, height_np, width_np = sess.run(
            [semantic_predictions, height, width], feed_dict=feed_dict)

        semantic_predictions_np = semantic_predictions_np[
            :height_np, :width_np]

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

        if (step + 1) % 100 == 0:
            print('Processed batches: ', (step + 1), '/', num_steps)

    coord.request_stop()
    coord.join(threads)
Example #4
def main():
    """Create the model and start the Inference process.
  """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image = reader.image
        cluster_label = reader.cluster_label
        loc_feature = reader.loc_feature
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(tf.expand_dims(image, axis=0), args.embedding_dim, False,
                    True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = tf.global_variables()

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image)[:2])
    embedding = common_utils.normalize_embedding(output)
    embedding = tf.squeeze(embedding, axis=0)

    image = image[:height, :width]
    embedding = tf.reshape(embedding[:height, :width],
                           [-1, args.embedding_dim])
    cluster_label = tf.reshape(cluster_label[:height, :width], [-1])
    loc_feature = tf.reshape(loc_feature[:height, :width], [-1, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features for kmeans clustering.
    embedding_with_location = tf.concat([embedding, loc_feature], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_label = common_utils.kmeans_with_initial_labels(
        embedding_with_location, cluster_label,
        args.num_clusters * args.num_clusters, args.kmeans_iterations)
    _, cluster_labels = tf.unique(cluster_label)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    cluster_labels = tf.reshape(cluster_labels, [height, width])

    # Predict semantic labels.
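    # For each test prototype, retrieve the k most similar training
    # prototypes by inner product of the embeddings.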
    similarities = tf.matmul(test_prototypes,
                             prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)

    prototype_semantic_predictions = eval_utils.k_nearest_neighbors(
        k_predictions, prototype_labels)
    semantic_predictions = tf.gather(prototype_semantic_predictions,
                                     cluster_labels)

    # Visualize embedding using PCA
    embedding = vis_utils.pca(
        tf.reshape(embedding, [1, height, width, args.embedding_dim]))
    embedding = ((embedding - tf.reduce_min(embedding)) /
                 (tf.reduce_max(embedding) - tf.reduce_min(embedding)))
    embedding = tf.cast(embedding * 255, tf.uint8)
    embedding = tf.squeeze(embedding, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    cluster_dir = os.path.join(args.save_dir, 'cluster')
    embedding_dir = os.path.join(args.save_dir, 'embedding')
    patch_dir = os.path.join(args.save_dir, 'test_patches')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)
    if not os.path.isdir(cluster_dir):
        os.makedirs(cluster_dir)
    if not os.path.isdir(embedding_dir):
        os.makedirs(embedding_dir)
    if not os.path.isdir(patch_dir):
        os.makedirs(patch_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    f = html_helper.open_html_for_write(
        os.path.join(args.save_dir, 'index.html'),
        'Visualization for Segment Collaging')
    for step in range(num_steps):
        (image_np, semantic_predictions_np, cluster_labels_np, embedding_np,
         k_predictions_np) = sess.run(
             [image, semantic_predictions, cluster_labels, embedding,
              k_predictions],
             feed_dict=feed_dict)

        imgname = os.path.basename(image_list[step])
        basename = imgname.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

        clustername = os.path.join(cluster_dir, basename)
        cluster = colormap[cluster_labels_np]
        Image.fromarray(cluster, mode='RGB').save(clustername)

        embeddingname = os.path.join(embedding_dir, basename)
        Image.fromarray(embedding_np, mode='RGB').save(embeddingname)

        image_np = (image_np + IMG_MEAN).astype(np.uint8)
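        # Save a bounding-box crop of every cluster segment (with non-segment
        # pixels zeroed) for the HTML collage below.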
        for i in range(np.max(cluster_labels_np) + 1):
            image_temp = copy.deepcopy(image_np)
            image_temp[cluster_labels_np != i] = 0
            coords = np.where(cluster_labels_np == i)
            crop = image_temp[np.min(coords[0]):np.max(coords[0]) + 1,
                              np.min(coords[1]):np.max(coords[1]) + 1]
            Image.fromarray(crop, mode='RGB').save(
                os.path.join(patch_dir, basename + str(i).zfill(3) + '.png'))

        html_helper.write_vmf_to_html(
            f, './images/' + imgname, './labels/' + basename,
            './color/' + basename, './cluster/' + basename,
            './embedding/' + basename, './test_patches/' + basename,
            './patches/', k_predictions_np)

        if (step + 1) % 100 == 0:
            print('Processed batches: ', (step + 1), '/', num_steps)

    html_helper.close_html(f)
    coord.request_stop()
    coord.join(threads)
Example #5
def main():
    """Create the model and start the inference process.
  """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert input_size is not None and strides is not None
    h, w = input_size
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image
        image_list = reader.image_list
    image_batch = tf.expand_dims(image, axis=0)

    # Create multi-scale augmented data.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented data.
    if args.flip_aug:
        for i in range(len(scales)):
            img = tf.squeeze(rescale_image_batches[i], axis=0)
            flip_img = tf.image.flip_left_right(img)
            flip_img = tf.expand_dims(flip_img, axis=0)
            rescale_image_batches.append(flip_img)
            is_flipped.append(True)
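
    # Each scale (and optionally its mirror) is segmented independently; the
    # per-scale class scores are later resized back and summed before argmax.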

    # Create input tensor to the network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]])

    embedding_input = tf.placeholder(tf.float32,
                                     shape=[1, None, None, args.embedding_dim])
    embedding = common_utils.normalize_embedding(embedding_input)
    loc_feature = tf.placeholder(tf.float32, shape=[1, None, None, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features for kmeans clustering.
    shape = tf.shape(embedding)
    embedding_with_location = tf.concat([embedding, loc_feature], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
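    # Retrieve the k nearest training prototypes for each test prototype,
    # map their indices to semantic labels, and broadcast the labels from
    # clusters to pixels as a [N, H, W, k] tensor.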
    similarities = tf.matmul(test_prototypes,
                             prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)
    k_predictions = tf.gather(prototype_labels, k_predictions)
    k_predictions = tf.gather(k_predictions, cluster_labels)
    k_predictions = tf.reshape(
        k_predictions,
        [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors])

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))
    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    for step in range(num_steps):
        rescale_img_batches = sess.run(rescale_image_batches)
        # Final segmentation results (average across multiple scales).
        scale_ind = 2 if args.scale_aug else 0
        final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
        final_lab_size[-1] = args.num_classes
        final_lab_batch = np.zeros(final_lab_size)

        # Iterate over multiple scales.
        for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
            img_size = img_batch.shape
            padimg_size = list(img_size)  # copy so padding can enlarge dims

            padimg_h, padimg_w = padimg_size[1:3]
            input_h, input_w = input_size

            if input_h > padimg_h:
                padimg_h = input_h
            if input_w > padimg_w:
                padimg_w = input_w
            # Update padded image size.
            padimg_size[1] = padimg_h
            padimg_size[2] = padimg_w
            padimg_batch = np.zeros(padimg_size, dtype=np.float32)
            img_h, img_w = img_size[1:3]
            padimg_batch[:, :img_h, :img_w, :] = img_batch

            stride_h, stride_w = strides
            npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
            npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1

            # Create padded prediction array.
            pred_size = list(padimg_size)
            pred_size[-1] = args.num_classes
            predictions_np = np.zeros(pred_size, dtype=np.int32)

            # Create the ending index of each patch.
            patch_indh = np.linspace(input_h,
                                     padimg_h,
                                     npatches_h,
                                     dtype=np.int32)
            patch_indw = np.linspace(input_w,
                                     padimg_w,
                                     npatches_w,
                                     dtype=np.int32)

            pred_size[-1] = args.embedding_dim
            embedding_all_np = np.zeros(pred_size, dtype=np.float32)
            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start & end ind of H
                    sw, ew = indw - input_w, indw  # start & end ind of W
                    cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]

                    embedding_np = sess.run(
                        output, feed_dict={crop_image_batch: cropimg_batch})
                    embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np

            loc_feature_np = common_utils.generate_location_features_np(
                [padimg_h, padimg_w])
            feed_dict[embedding_input] = embedding_all_np
            feed_dict[loc_feature] = loc_feature_np
            k_predictions_np = sess.run(k_predictions, feed_dict=feed_dict)
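            # Each of the k retrieved prototypes casts one vote for its class
            # at every pixel of its segment.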
            for c in range(args.num_classes):
                predictions_np[:, :, :, c] += np.sum(
                    (k_predictions_np == c).astype(np.int32), axis=3)

            predictions_np = predictions_np[0, :img_h, :img_w, :]
            lab_batch = predictions_np.astype(np.float32)
            # Rescale prediction back to original resolution.
            lab_batch = cv2.resize(lab_batch,
                                   (final_lab_size[1], final_lab_size[0]),
                                   interpolation=cv2.INTER_LINEAR)
            if is_flip:
                # Flipped prediction back to original orientation.
                lab_batch = lab_batch[:, ::-1, :]
            final_lab_batch += lab_batch

        final_lab_ind = np.argmax(final_lab_batch, axis=-1)
        final_lab_ind = final_lab_ind.astype(np.uint8)

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(final_lab_ind, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[final_lab_ind]
        Image.fromarray(color, mode='RGB').save(colorname)

        if (step + 1) % 100 == 0:
            print('Processed batches: ', (step + 1), '/', num_steps)

    coord.request_stop()
    coord.join(threads)