def main():
    """Create the model and extract prototypes with patch-wise inference.

    Every image produced by the reader is zero-padded so that at least one
    crop of ``input_size`` fits. The network is run on overlapping crops
    (placed according to ``strides``) and the crop outputs are summed into a
    full-size embedding map. That map, together with location and RGB
    features, is clustered by ``common_utils.kmeans_with_initial_labels``;
    pixels selected by ``eval_utils.find_majority_label_index`` are turned
    into prototypes by ``common_utils.calculate_prototypes_from_labels``.

    The prototypes of all images are concatenated and written to
    ``<save_dir>/prototypes/prototype_features.npy`` and
    ``<save_dir>/prototypes/prototype_labels.npy``.

    Raises:
        ValueError: if ``input_size`` or ``strides`` is missing, or if the
            data list has no entries.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    if input_size is None or strides is None:
        raise ValueError('Both input_size and strides must be specified.')
    input_h, input_w = input_size
    stride_h, stride_w = strides

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image_batch = tf.expand_dims(reader.image, axis=0)
        label_batch = tf.expand_dims(reader.label, axis=0)

    # Create input tensor to the Network
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_h, input_w, 3],
        dtype=tf.float32)

    # Create network and output prediction.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name]

    # Output predictions, resized to the crop resolution.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_h, input_w])

    # Input full-sized embedding
    label_input = tf.placeholder(tf.int32, shape=[1, None, None, 1])
    embedding_input = tf.placeholder(
        tf.float32, shape=[1, None, None, args.embedding_dim])
    embedding = common_utils.normalize_embedding(embedding_input)
    loc_feature = tf.placeholder(tf.float32, shape=[1, None, None, 2])
    rgb_feature = tf.placeholder(tf.float32, shape=[1, None, None, 3])

    # Combine embedding with location features and kmeans
    shape = tf.shape(embedding)
    cluster_labels = common_utils.initialize_cluster_labels(
        [args.num_clusters, args.num_clusters], [shape[1], shape[2]])
    embedding = tf.reshape(embedding, [-1, args.embedding_dim])
    labels = tf.reshape(label_input, [-1])
    cluster_labels = tf.reshape(cluster_labels, [-1])
    location_features = tf.reshape(loc_feature, [-1, 2])
    rgb_features = common_utils.normalize_embedding(
        tf.reshape(rgb_feature, [-1, 3])) / args.embedding_dim

    # Collect pixels of valid semantic classes.
    valid_pixels = tf.where(tf.not_equal(labels, args.ignore_label))
    labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1)
    cluster_labels = tf.squeeze(
        tf.gather(cluster_labels, valid_pixels), axis=1)
    embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
    location_features = tf.squeeze(
        tf.gather(location_features, valid_pixels), axis=1)
    rgb_features = tf.squeeze(tf.gather(rgb_features, valid_pixels), axis=1)

    # Generate cluster labels via kmeans clustering.
    embedding_with_location = tf.concat(
        [embedding, location_features, rgb_features], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_labels = common_utils.kmeans_with_initial_labels(
        embedding_with_location,
        cluster_labels,
        args.num_clusters * args.num_clusters,
        args.kmeans_iterations)
    # Re-index the cluster ids to consecutive integers.
    _, cluster_labels = tf.unique(cluster_labels)

    # Find pixels of majority semantic classes.
    select_pixels, prototype_labels = eval_utils.find_majority_label_index(
        labels, cluster_labels)

    # Calculate the prototype features.
    cluster_labels = tf.squeeze(
        tf.gather(cluster_labels, select_pixels), axis=1)
    embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)
    prototype_features = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Everything below runs with the queue threads alive; the finally clause
    # makes sure they are stopped even when a step fails.
    try:
        # Create directory for saving prototypes.
        save_dir = os.path.join(args.save_dir, 'prototypes')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        # One step per non-empty line, regardless of a trailing newline.
        with open(args.data_list, 'r') as listf:
            num_steps = sum(1 for line in listf if line.strip())
        if num_steps == 0:
            raise ValueError(
                'No entries found in data list: %s' % args.data_list)

        # Per-image results are collected and concatenated once at the end.
        feature_chunks = []
        label_chunks = []

        # Iterate over testing steps.
        for _ in tqdm(range(num_steps)):
            image_batch_np, label_batch_np = sess.run(
                [image_batch, label_batch])
            img_h, img_w = image_batch_np.shape[1:3]

            # Zero-pad the image so that at least one full crop fits.
            padded_h = max(input_h, img_h)
            padded_w = max(input_w, img_w)
            padded_img_size = list(image_batch_np.shape)
            padded_img_size[1] = padded_h
            padded_img_size[2] = padded_w
            padded_img_batch = np.zeros(padded_img_size, dtype=np.float32)
            padded_img_batch[:, :img_h, :img_w, :] = image_batch_np

            # np.linspace needs an integer sample count, while math.ceil
            # returns a float on Python 2.
            npatches_h = int(
                math.ceil(1.0 * (padded_h - input_h) / stride_h)) + 1
            npatches_w = int(
                math.ceil(1.0 * (padded_w - input_w) / stride_w)) + 1

            # Create the ending index of each patch.
            patch_indh = np.linspace(
                input_h, padded_h, npatches_h, dtype=np.int32)
            patch_indw = np.linspace(
                input_w, padded_w, npatches_w, dtype=np.int32)

            # Create embedding holder.
            embedding_all_np = np.zeros(
                [padded_img_size[0], padded_h, padded_w, args.embedding_dim],
                dtype=np.float32)
            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start & end ind of H
                    sw, ew = indw - input_w, indw  # start & end ind of W
                    cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :]
                    embedding_np = sess.run(
                        output, feed_dict={crop_image_batch: cropimg_batch})
                    # Overlapping crops are summed; the graph normalizes
                    # the fed embedding afterwards.
                    embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np
            embedding_all_np = embedding_all_np[:, :img_h, :img_w, :]

            # NOTE(review): the embedding is cropped back to the image size,
            # but the location and RGB features keep the padded size. The
            # flattened tensors only line up when the image is at least as
            # large as input_size - confirm smaller images never occur.
            loc_feature_np = common_utils.generate_location_features_np(
                [padded_h, padded_w])
            feed_dict = {
                label_input: label_batch_np,
                embedding_input: embedding_all_np,
                loc_feature: loc_feature_np,
                rgb_feature: padded_img_batch,
            }

            batch_prototype_features_np, batch_prototype_labels_np = sess.run(
                [prototype_features, prototype_labels], feed_dict=feed_dict)
            feature_chunks.append(batch_prototype_features_np)
            label_chunks.append(batch_prototype_labels_np)

        prototype_features_np = np.concatenate(feature_chunks, axis=0)
        prototype_labels_np = np.concatenate(label_chunks, axis=0)

        print ('Total number of prototypes extracted: ',
               len(prototype_labels_np))

        # The context manager closes (and thereby flushes) each output file.
        outputs_to_save = (
            ('prototype_features', prototype_features_np),
            ('prototype_labels', prototype_labels_np),
        )
        for array_name, array in outputs_to_save:
            npy_path = '%s/%s.npy' % (save_dir, array_name)
            with tf.gfile.Open(npy_path, mode='w') as npy_file:
                np.save(npy_file, array)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and extract prototypes from whole images.

    The network is run directly on each image batch from the reader. Its
    output is resized to the image resolution and normalized, then passed to
    ``eval_utils.extract_trained_prototypes`` together with the location
    features, initial cluster labels and semantic labels from the reader.

    The prototypes of all images are concatenated and written to
    ``<save_dir>/prototypes/prototype_features.npy`` and
    ``<save_dir>/prototypes/prototype_labels.npy``.

    Raises:
        ValueError: if ``input_size`` or ``strides`` is missing, or if the
            data list has no entries.
    """
    args = get_arguments()

    # Parse image processing arguments. Neither value is used below; the
    # check is retained so that missing arguments still fail up front.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    if input_size is None or strides is None:
        raise ValueError('Both input_size and strides must be specified.')

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image_batch = tf.expand_dims(reader.image, axis=0)
        label_batch = tf.expand_dims(reader.label, axis=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name]

    # Output predictions, resized to the input image resolution.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3])
    embedding = common_utils.normalize_embedding(output)

    # The Python loop below needs the statically known batch size.
    batch_size = embedding.get_shape().as_list()[0]

    labels = label_batch
    # Cluster labels and location features of the first batch item are
    # shared by every iteration of the loop.
    initial_cluster_labels = cluster_label_batch[0, :, :]
    location_features = tf.reshape(loc_feature_batch[0, :, :], [-1, 2])

    prototype_feature_list = []
    prototype_label_list = []
    for bs in range(batch_size):
        cur_labels = tf.reshape(labels[bs], [-1])
        cur_cluster_labels = tf.reshape(initial_cluster_labels, [-1])
        cur_embedding = tf.reshape(embedding[bs], [-1, args.embedding_dim])

        cur_features, cur_prototype_labels, _ = (
            eval_utils.extract_trained_prototypes(
                cur_embedding,
                location_features,
                cur_cluster_labels,
                args.num_clusters * args.num_clusters,
                args.kmeans_iterations,
                cur_labels,
                1,
                args.ignore_label,
                'semantic'))

        prototype_feature_list.append(cur_features)
        prototype_label_list.append(cur_prototype_labels)

    prototype_features = tf.concat(prototype_feature_list, axis=0)
    prototype_labels = tf.concat(prototype_label_list, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Everything below runs with the queue threads alive; the finally clause
    # makes sure they are stopped even when a step fails.
    try:
        # Create directory for saving prototypes.
        save_dir = os.path.join(args.save_dir, 'prototypes')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        # One step per non-empty line, regardless of a trailing newline.
        with open(args.data_list, 'r') as listf:
            num_steps = sum(1 for line in listf if line.strip())
        if num_steps == 0:
            raise ValueError(
                'No entries found in data list: %s' % args.data_list)

        # Per-image results are collected and concatenated once at the end.
        feature_chunks = []
        label_chunks = []

        # Iterate over testing steps.
        for step in range(num_steps):
            batch_prototype_features_np, batch_prototype_labels_np = sess.run(
                [prototype_features, prototype_labels])
            feature_chunks.append(batch_prototype_features_np)
            label_chunks.append(batch_prototype_labels_np)

            if (step + 1) % 100 == 0:
                print('Processed batches: ', (step + 1), '/', num_steps)

        prototype_features_np = np.concatenate(feature_chunks, axis=0)
        prototype_labels_np = np.concatenate(label_chunks, axis=0)

        print ('Total number of prototypes extracted: ',
               len(prototype_labels_np))

        # The context manager closes (and thereby flushes) each output file.
        outputs_to_save = (
            ('prototype_features', prototype_features_np),
            ('prototype_labels', prototype_labels_np),
        )
        for array_name, array in outputs_to_save:
            npy_path = '%s/%s.npy' % (save_dir, array_name)
            with tf.gfile.Open(npy_path, mode='w') as npy_file:
                np.save(npy_file, array)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and start the inference process.

    The network is run on each image from the reader. Its normalized output
    is combined with the location features and clustered by
    ``common_utils.kmeans``; the resulting cluster prototypes are labelled by
    ``eval_utils.predict_semantic_instance_labels`` against the prototypes
    loaded from ``<prototype_dir>/prototype_features.npy`` and
    ``<prototype_dir>/prototype_labels.npy``.

    For every image the prediction is cropped to the height and width
    reported by the reader and saved as a grayscale PNG in
    ``<save_dir>/gray`` and as a colormapped PNG in ``<save_dir>/color``.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, axis=0)
        label_batch = tf.expand_dims(reader.label, axis=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name]

    # Output predictions, resized to the input image resolution.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(
        tf.float32, shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features.
    embedding_with_location = tf.concat([embedding, loc_feature_batch], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Kmeans clustering.
    cluster_labels = common_utils.kmeans(
        embedding_with_location,
        [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
    semantic_predictions, _ = eval_utils.predict_semantic_instance_labels(
        cluster_labels,
        test_prototypes,
        prototype_features,
        prototype_labels,
        None,
        args.k_in_nearest_neighbors)
    # uint8 is what the mode 'L' image written below expects.
    semantic_predictions = tf.cast(semantic_predictions, tf.uint8)
    semantic_predictions = tf.squeeze(semantic_predictions)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Everything below runs with the queue threads alive; the finally clause
    # makes sure they are stopped even when a step fails.
    try:
        # Get colormap. The .mat variable is looked up by the file's stem.
        map_data = scipy.io.loadmat(args.colormap)
        key = os.path.basename(args.colormap).replace('.mat', '')
        colormap = map_data[key]
        colormap *= 255
        colormap = colormap.astype(np.uint8)

        # Create directory for saving predictions.
        pred_dir = os.path.join(args.save_dir, 'gray')
        color_dir = os.path.join(args.save_dir, 'color')
        for out_dir in (pred_dir, color_dir):
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

        # One step per non-empty line, regardless of a trailing newline.
        with open(args.data_list, 'r') as listf:
            num_steps = sum(1 for line in listf if line.strip())

        # Load prototype features and labels
        prototype_features_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_features.npy'))
        prototype_labels_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_labels.npy'))
        feed_dict = {
            prototype_features: prototype_features_np,
            prototype_labels: prototype_labels_np,
        }

        # Iterate over testing steps.
        for step in range(num_steps):
            semantic_predictions_np, height_np, width_np = sess.run(
                [semantic_predictions, height, width], feed_dict=feed_dict)

            # Keep only the region reported by the reader as the image.
            semantic_predictions_np = (
                semantic_predictions_np[:height_np, :width_np])

            # Swap the extension only; a plain substring replace would also
            # rewrite "jpg" inside the stem and miss other extensions.
            stem = os.path.splitext(os.path.basename(image_list[step]))[0]
            basename = stem + '.png'

            predname = os.path.join(pred_dir, basename)
            Image.fromarray(semantic_predictions_np, mode='L').save(predname)

            colorname = os.path.join(color_dir, basename)
            color = colormap[semantic_predictions_np]
            Image.fromarray(color, mode='RGB').save(colorname)

            if (step + 1) % 100 == 0:
                print('Processed batches: ', (step + 1), '/', num_steps)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model, run inference and write an HTML visualization.

    The network is run on each image from the reader. Its normalized output,
    cropped to the height and width reported by the reader, is clustered by
    ``common_utils.kmeans_with_initial_labels``. Each cluster prototype is
    compared with the prototypes loaded from ``<prototype_dir>`` and labelled
    by ``eval_utils.k_nearest_neighbors``.

    Written under ``<save_dir>`` for every image:
        gray/          predicted labels as a grayscale PNG
        color/         predicted labels mapped through the colormap
        cluster/       cluster ids mapped through the colormap
        embedding/     the embedding after ``vis_utils.pca``, scaled to 0-255
        test_patches/  one masked image crop per cluster
    plus one ``index.html`` entry per image via ``html_helper``.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image = reader.image
        cluster_label = reader.cluster_label
        loc_feature = reader.loc_feature
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(
        tf.expand_dims(image, axis=0), args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = tf.global_variables()

    # Output predictions, resized to the (uncropped) image resolution.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image)[:2])
    embedding = common_utils.normalize_embedding(output)
    embedding = tf.squeeze(embedding, axis=0)

    # Crop every per-pixel tensor to the reader-reported height and width.
    image = image[:height, :width]
    embedding = tf.reshape(
        embedding[:height, :width], [-1, args.embedding_dim])
    cluster_label = tf.reshape(cluster_label[:height, :width], [-1])
    loc_feature = tf.reshape(loc_feature[:height, :width], [-1, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(
        tf.float32, shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans
    embedding_with_location = tf.concat([embedding, loc_feature], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_label = common_utils.kmeans_with_initial_labels(
        embedding_with_location,
        cluster_label,
        args.num_clusters * args.num_clusters,
        args.kmeans_iterations)
    # Re-index the cluster ids to consecutive integers, so every id in
    # [0, max] is used by at least one pixel.
    _, cluster_labels = tf.unique(cluster_label)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    cluster_labels = tf.reshape(cluster_labels, [height, width])

    # Predict semantic labels.
    similarities = tf.matmul(
        test_prototypes, prototype_features, transpose_b=True)
    _, k_predictions = tf.nn.top_k(
        similarities, k=args.k_in_nearest_neighbors, sorted=True)

    prototype_semantic_predictions = eval_utils.k_nearest_neighbors(
        k_predictions, prototype_labels)
    semantic_predictions = tf.gather(
        prototype_semantic_predictions, cluster_labels)

    # Visualize embedding using PCA, rescaled to the 0-255 range.
    embedding = vis_utils.pca(
        tf.reshape(embedding, [1, height, width, args.embedding_dim]))
    embedding_min = tf.reduce_min(embedding)
    embedding_max = tf.reduce_max(embedding)
    embedding = (embedding - embedding_min) / (embedding_max - embedding_min)
    embedding = tf.cast(embedding * 255, tf.uint8)
    embedding = tf.squeeze(embedding, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Everything below runs with the queue threads alive; the finally clause
    # makes sure they are stopped even when a step fails.
    try:
        # Get colormap. The .mat variable is looked up by the file's stem.
        map_data = scipy.io.loadmat(args.colormap)
        key = os.path.basename(args.colormap).replace('.mat', '')
        colormap = map_data[key]
        colormap *= 255
        colormap = colormap.astype(np.uint8)

        # Create directory for saving predictions.
        pred_dir = os.path.join(args.save_dir, 'gray')
        color_dir = os.path.join(args.save_dir, 'color')
        cluster_dir = os.path.join(args.save_dir, 'cluster')
        embedding_dir = os.path.join(args.save_dir, 'embedding')
        patch_dir = os.path.join(args.save_dir, 'test_patches')
        all_dirs = (pred_dir, color_dir, cluster_dir, embedding_dir, patch_dir)
        for out_dir in all_dirs:
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

        # One step per non-empty line, regardless of a trailing newline.
        with open(args.data_list, 'r') as listf:
            num_steps = sum(1 for line in listf if line.strip())

        # Load prototype features and labels
        prototype_features_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_features.npy'))
        prototype_labels_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_labels.npy'))
        feed_dict = {
            prototype_features: prototype_features_np,
            prototype_labels: prototype_labels_np,
        }

        f = html_helper.open_html_for_write(
            os.path.join(args.save_dir, 'index.html'),
            'Visualization for Segment Collaging')
        try:
            # Iterate over testing steps.
            for step in range(num_steps):
                (image_np, semantic_predictions_np, cluster_labels_np,
                 embedding_np, k_predictions_np) = sess.run(
                     [image, semantic_predictions, cluster_labels,
                      embedding, k_predictions],
                     feed_dict=feed_dict)

                # Mode 'L' reads the buffer as one byte per pixel, so the
                # predictions must be uint8 before being written.
                semantic_predictions_np = semantic_predictions_np.astype(
                    np.uint8)

                imgname = os.path.basename(image_list[step])
                # Swap the extension only; a plain substring replace would
                # also rewrite "jpg" inside the stem.
                basename = os.path.splitext(imgname)[0] + '.png'

                predname = os.path.join(pred_dir, basename)
                Image.fromarray(
                    semantic_predictions_np, mode='L').save(predname)

                colorname = os.path.join(color_dir, basename)
                color = colormap[semantic_predictions_np]
                Image.fromarray(color, mode='RGB').save(colorname)

                clustername = os.path.join(cluster_dir, basename)
                cluster = colormap[cluster_labels_np]
                Image.fromarray(cluster, mode='RGB').save(clustername)

                embeddingname = os.path.join(embedding_dir, basename)
                Image.fromarray(embedding_np, mode='RGB').save(embeddingname)

                # Add the mean back and clip before the uint8 cast, so
                # slightly out-of-range values cannot wrap around.
                image_np = np.clip(image_np + IMG_MEAN, 0, 255).astype(
                    np.uint8)

                # Save one patch per cluster: the cluster's bounding box
                # with every pixel outside the cluster blacked out.
                for i in range(np.max(cluster_labels_np) + 1):
                    mask = cluster_labels_np == i
                    rows, cols = np.where(mask)
                    # Slice ends are exclusive, hence max + 1; without it
                    # the last row/column is dropped and one-pixel-wide
                    # clusters give an empty crop.
                    top, bottom = rows.min(), rows.max() + 1
                    left, right = cols.min(), cols.max() + 1
                    crop = image_np[top:bottom, left:right].copy()
                    crop[~mask[top:bottom, left:right]] = 0
                    # PIL replaces scipy.misc.imsave, which newer SciPy
                    # releases no longer provide.
                    Image.fromarray(crop).save(
                        patch_dir + '/' + basename + str(i).zfill(3) + '.png')

                html_helper.write_vmf_to_html(
                    f,
                    './images/' + imgname,
                    './labels/' + basename,
                    './color/' + basename,
                    './cluster/' + basename,
                    './embedding/' + basename,
                    './test_patches/' + basename,
                    './patches/',
                    k_predictions_np)

                if (step + 1) % 100 == 0:
                    print('Processed batches: ', (step + 1), '/', num_steps)
        finally:
            # Close the page even if a step fails part-way through.
            html_helper.close_html(f)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and start multi-scale, patch-wise inference.

    Each image from the reader is optionally rescaled (``scale_aug``) and
    horizontally flipped (``flip_aug``). Every variant is zero-padded so that
    at least one crop of ``input_size`` fits, the network is run on
    overlapping crops (placed according to ``strides``) and the crop outputs
    are summed into a full-size embedding map. The map is clustered by
    ``common_utils.kmeans`` and each cluster prototype retrieves the labels
    of its ``k_in_nearest_neighbors`` most similar prototypes loaded from
    ``<prototype_dir>``.

    The retrieved labels are counted per class, resized to the resolution of
    the scale-1 variant, un-flipped and summed over all variants. The argmax
    is saved as a grayscale PNG in ``<save_dir>/gray`` and as a colormapped
    PNG in ``<save_dir>/color``.

    Raises:
        ValueError: if ``input_size`` or ``strides`` is missing.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    if input_size is None or strides is None:
        raise ValueError('Both input_size and strides must be specified.')
    input_h, input_w = input_size
    stride_h, stride_w = strides

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, axis=0)

    # Create multi-scale augmented datas.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    # The scale-1 variant defines the resolution of the final prediction.
    scale_ind = scales.index(1)
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented datas.
    if args.flip_aug:
        # Iterate over a snapshot because the list grows inside the loop.
        for scaled_batch in list(rescale_image_batches):
            flip_img = tf.image.flip_left_right(
                tf.squeeze(scaled_batch, axis=0))
            rescale_image_batches.append(tf.expand_dims(flip_img, axis=0))
            is_flipped.append(True)

    # Create input tensor to the Network
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_h, input_w, 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name]

    # Output predictions, resized to the crop resolution.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_h, input_w])

    embedding_input = tf.placeholder(
        tf.float32, shape=[1, None, None, args.embedding_dim])
    embedding = common_utils.normalize_embedding(embedding_input)
    loc_feature = tf.placeholder(tf.float32, shape=[1, None, None, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(
        tf.float32, shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans
    shape = tf.shape(embedding)
    embedding_with_location = tf.concat([embedding, loc_feature], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_labels = common_utils.kmeans(
        embedding_with_location,
        [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    # Re-index the cluster ids to consecutive integers.
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels: per cluster, the labels of its k most similar
    # stored prototypes, broadcast back to every pixel of the cluster.
    similarities = tf.matmul(
        test_prototypes, prototype_features, transpose_b=True)
    _, k_predictions = tf.nn.top_k(
        similarities, k=args.k_in_nearest_neighbors, sorted=True)
    k_predictions = tf.gather(prototype_labels, k_predictions)
    k_predictions = tf.gather(k_predictions, cluster_labels)
    k_predictions = tf.reshape(
        k_predictions,
        [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors])

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Everything below runs with the queue threads alive; the finally clause
    # makes sure they are stopped even when a step fails.
    try:
        # Get colormap. The .mat variable is looked up by the file's stem.
        map_data = scipy.io.loadmat(args.colormap)
        key = os.path.basename(args.colormap).replace('.mat', '')
        colormap = map_data[key]
        colormap *= 255
        colormap = colormap.astype(np.uint8)

        # Create directory for saving predictions.
        pred_dir = os.path.join(args.save_dir, 'gray')
        color_dir = os.path.join(args.save_dir, 'color')
        for out_dir in (pred_dir, color_dir):
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

        # One step per non-empty line, regardless of a trailing newline.
        with open(args.data_list, 'r') as listf:
            num_steps = sum(1 for line in listf if line.strip())

        # Load prototype features and labels
        prototype_features_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_features.npy'))
        prototype_labels_np = np.load(
            os.path.join(args.prototype_dir, 'prototype_labels.npy'))
        feed_dict = {
            prototype_features: prototype_features_np,
            prototype_labels: prototype_labels_np,
        }

        # Iterate over testing steps.
        for step in range(num_steps):
            rescale_img_batches = sess.run(rescale_image_batches)

            # Final segmentation results (summed across multiple scales).
            final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
            final_lab_size[-1] = args.num_classes
            final_lab_batch = np.zeros(final_lab_size)

            # Iterate over multiple scales.
            for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
                img_h, img_w = img_batch.shape[1:3]

                # Zero-pad the image so that at least one full crop fits.
                padimg_h = max(input_h, img_h)
                padimg_w = max(input_w, img_w)
                padimg_size = list(img_batch.shape)
                padimg_size[1] = padimg_h
                padimg_size[2] = padimg_w
                padimg_batch = np.zeros(padimg_size, dtype=np.float32)
                padimg_batch[:, :img_h, :img_w, :] = img_batch

                # np.linspace needs an integer sample count, while math.ceil
                # returns a float on Python 2.
                npatches_h = int(
                    math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
                npatches_w = int(
                    math.ceil(1.0 * (padimg_w - input_w) / stride_w)) + 1

                # Create padded prediction array.
                pred_size = list(padimg_size)
                pred_size[-1] = args.num_classes
                predictions_np = np.zeros(pred_size, dtype=np.int32)

                # Create the ending index of each patch.
                patch_indh = np.linspace(
                    input_h, padimg_h, npatches_h, dtype=np.int32)
                patch_indw = np.linspace(
                    input_w, padimg_w, npatches_w, dtype=np.int32)

                # Create embedding holder.
                pred_size[-1] = args.embedding_dim
                embedding_all_np = np.zeros(pred_size, dtype=np.float32)
                for indh in patch_indh:
                    for indw in patch_indw:
                        sh, eh = indh - input_h, indh  # start & end ind of H
                        sw, ew = indw - input_w, indw  # start & end ind of W
                        cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                        embedding_np = sess.run(
                            output,
                            feed_dict={crop_image_batch: cropimg_batch})
                        # Overlapping crops are summed; the graph normalizes
                        # the fed embedding afterwards.
                        embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np

                loc_feature_np = common_utils.generate_location_features_np(
                    [padimg_h, padimg_w])
                feed_dict[embedding_input] = embedding_all_np
                feed_dict[loc_feature] = loc_feature_np
                k_predictions_np = sess.run(k_predictions, feed_dict=feed_dict)

                # Count, per pixel, how many retrieved labels equal each
                # class. np.int32 replaces the removed np.int alias.
                for c in range(args.num_classes):
                    predictions_np[:, :, :, c] += np.sum(
                        (k_predictions_np == c).astype(np.int32), axis=3)

                # Drop the padding again.
                predictions_np = predictions_np[0, :img_h, :img_w, :]
                lab_batch = predictions_np.astype(np.float32)

                # Rescale prediction back to original resolution
                # (cv2 expects the size as (width, height)).
                lab_batch = cv2.resize(
                    lab_batch,
                    (final_lab_size[1], final_lab_size[0]),
                    interpolation=cv2.INTER_LINEAR)
                if is_flip:
                    # Flipped prediction back to original orientation.
                    lab_batch = lab_batch[:, ::-1, :]
                final_lab_batch += lab_batch

            final_lab_ind = np.argmax(final_lab_batch, axis=-1)
            final_lab_ind = final_lab_ind.astype(np.uint8)

            # Swap the extension only; a plain substring replace would also
            # rewrite "jpg" inside the stem and miss other extensions.
            stem = os.path.splitext(os.path.basename(image_list[step]))[0]
            basename = stem + '.png'

            predname = os.path.join(pred_dir, basename)
            Image.fromarray(final_lab_ind, mode='L').save(predname)

            colorname = os.path.join(color_dir, basename)
            color = colormap[final_lab_ind]
            Image.fromarray(color, mode='RGB').save(colorname)

            if (step + 1) % 100 == 0:
                print('Processed batches: ', (step + 1), '/', num_steps)
    finally:
        coord.request_stop()
        coord.join(threads)