def main():
    """Create the model and start training.

    Builds the segmentation graph (softmax cross-entropy on non-ignored
    pixels + step-decayed affinity loss + L2 weight decay on 'weights'
    variables), optionally restores a checkpoint, then runs
    ``args.num_steps`` training steps. TensorBoard summaries are written
    every ``args.update_tb_every`` steps; checkpoints are written every
    ``args.save_pred_every`` steps and once more after the final step.
    """
    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current step; fed every iteration to drive the LR and AFF-loss decay.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    # Set up tf session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Load the data reader.
    with tf.device('/cpu:0'):
        with tf.name_scope('create_inputs'):
            reader = ImageReader(
                args.data_dir,
                args.data_list,
                input_size,
                args.random_scale,
                args.random_mirror,
                args.random_crop,
                args.ignore_label,
                IMG_MEAN)
            # image_batch => (N, H, W, C=3)
            # label_batch => (N, H, W, 1)
            image_batch, label_batch = reader.dequeue(args.batch_size)

    # Shrink labels to the size of the network output.
    labels = tf.image.resize_nearest_neighbor(
        label_batch, innet_size, name='label_shrink')
    labels_flat = tf.reshape(labels, [-1])

    # Ignore locations whose label value is >= args.num_classes.
    not_ignore_pixel = tf.less_equal(labels_flat, args.num_classes - 1)

    # Indices of the pixels through which gradients are propagated.
    pixel_inds = tf.squeeze(tf.where(not_ignore_pixel), 1)

    # Create network and predictions.
    outputs = model(image_batch,
                    args.num_classes,
                    args.is_training,
                    args.use_global_status)

    # Variables to restore from checkpoints. Collected before the optimizers
    # are built, so it only covers variables created up to this point.
    restore_var = [
        v for v in tf.global_variables()
        if 'block5' not in v.name or not args.not_restore_classifier
    ]

    # Sum the losses from output branches.
    labels_gather = tf.to_int32(tf.gather(labels_flat, pixel_inds))
    seg_losses = []
    aff_losses = []
    for output in outputs:
        # Define softmax loss over non-ignored pixels only.
        output_2d = tf.reshape(output, [-1, args.num_classes])
        output_gather = tf.gather(output_2d, pixel_inds)
        seg_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=output_gather, labels=labels_gather)
        seg_losses.append(tf.reduce_mean(seg_loss))

        # Define AFF loss.
        prob = tf.nn.softmax(output, axis=-1)
        edge_loss, not_edge_loss = lossx.affinity_loss(
            labels, prob, args.num_classes, args.kld_margin)

        # Apply exponential decay to the AFF loss: 10^(-step/num_steps).
        dec = tf.pow(10.0, -step_ph / args.num_steps)
        aff_loss = tf.reduce_mean(edge_loss) * args.kld_lambda_1 * dec
        aff_loss += tf.reduce_mean(not_edge_loss) * args.kld_lambda_2 * dec
        aff_losses.append(aff_loss)

    # Define weight regularization loss.
    weight_decay = args.weight_decay
    l2_losses = [weight_decay * tf.nn.l2_loss(v)
                 for v in tf.trainable_variables() if 'weights' in v.name]

    # Sum all loss terms.
    mean_seg_loss = tf.add_n(seg_losses)
    mean_aff_loss = tf.add_n(aff_losses)
    mean_l2_loss = tf.add_n(l2_losses)
    reduced_loss = mean_seg_loss + mean_l2_loss + mean_aff_loss

    # Grab variable names which are used for training.
    all_trainable = tf.trainable_variables()
    fc_trainable = [v for v in all_trainable if 'block5' in v.name]  # lr*10
    base_trainable = [v for v in all_trainable if 'block5' not in v.name]  # lr*1

    # Computes gradients per iteration.
    grads = tf.gradients(reduced_loss, base_trainable + fc_trainable)
    grads_base = grads[:len(base_trainable)]
    grads_fc = grads[len(base_trainable):]

    # Define optimisation parameters (polynomial LR decay).
    base_lr = tf.constant(args.learning_rate)
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
    opt_base = tf.train.MomentumOptimizer(learning_rate * 1.0, args.momentum)
    opt_fc = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)

    # Define tensorflow operations which apply gradients to update variables.
    train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable))
    train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
    train_op = tf.group(train_op_base, train_op_fc)

    # Process for visualisation.
    with tf.device('/cpu:0'):
        # Image summary for input image, ground-truth label and prediction.
        output_vis = tf.image.resize_nearest_neighbor(
            outputs[-1], tf.shape(image_batch)[1:3, ])
        output_vis = tf.argmax(output_vis, axis=3)
        output_vis = tf.expand_dims(output_vis, axis=3)
        output_vis = tf.cast(output_vis, dtype=tf.uint8)

        labels_vis = tf.cast(label_batch, dtype=tf.uint8)

        in_summary = tf.py_func(
            utils.general.inv_preprocess,
            [image_batch, IMG_MEAN],
            tf.uint8)
        gt_summary = tf.py_func(
            utils.general.decode_labels,
            [labels_vis, args.num_classes],
            tf.uint8)
        out_summary = tf.py_func(
            utils.general.decode_labels,
            [output_vis, args.num_classes],
            tf.uint8)

        # Concatenate image summaries in a row.
        tf.summary.image(
            'images',
            tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
            max_outputs=args.batch_size)

        # Scalar summary for different loss terms.
        tf.summary.scalar('seg_loss', mean_seg_loss)
        tf.summary.scalar('aff_loss', mean_aff_loss)

        total_summary = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(
            args.snapshot_dir, graph=tf.get_default_graph())

    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Load variables if the checkpoint is provided.
    if args.restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over training steps.
        pbar = tqdm(range(args.num_steps))
        for step in pbar:
            feed_dict = {step_ph: step}

            # Each sub-iteration runs train_op (gradients are applied every
            # sub-iteration, not accumulated); only the loss is averaged.
            step_loss = 0
            for it in range(args.iter_size):
                # Update summary periodically.
                if (it == args.iter_size - 1
                        and step % args.update_tb_every == 0):
                    loss_value, summary, _ = sess.run(
                        [reduced_loss, total_summary, train_op],
                        feed_dict=feed_dict)
                    summary_writer.add_summary(summary, step)
                else:
                    loss_value, _ = sess.run(
                        [reduced_loss, train_op], feed_dict=feed_dict)
                step_loss += loss_value
            step_loss /= args.iter_size

            lr = sess.run(learning_rate, feed_dict=feed_dict)

            # Save periodically, and always after the last step so the final
            # weights are never lost.
            is_last_step = step == args.num_steps - 1
            if (step % args.save_pred_every == 0 and step > 0) or is_last_step:
                save(saver, sess, args.snapshot_dir, step)

            pbar.set_description(
                'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr))
    finally:
        # Always shut down the input queue threads and flush summaries.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and start the inference process.

    Runs sliding-window inference (patches of ``args.input_size`` with
    ``args.strides``) on every image listed in ``args.data_list``,
    optionally with multi-scale (``args.scale_aug``) and horizontal-flip
    (``args.flip_aug``) test-time augmentation. Softmax probabilities are
    averaged over overlapping patches, resized to the scale-1 resolution and
    summed over all augmentations. The argmax label map is saved as a
    gray-scale PNG in ``<save_dir>/gray`` and a colour-mapped PNG in
    ``<save_dir>/color``.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image
        image_list = reader.image_list
        image_batch = tf.expand_dims(image, axis=0)

    # Create multi-scale augmented datas.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented datas (one per scale).
    if args.flip_aug:
        for i in range(len(scales)):
            img = tf.squeeze(rescale_image_batches[i], axis=0)
            flip_img = tf.expand_dims(tf.image.flip_left_right(img), axis=0)
            rescale_image_batches.append(flip_img)
            is_flipped.append(True)

    # Create input tensor to the Network
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.num_classes, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions: per-pixel class probabilities at crop resolution.
    output = outputs[-1]
    output = tf.image.resize_bilinear(
        output, tf.shape(crop_image_batch)[1:3, ])
    output = tf.nn.softmax(output, axis=3)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap (stored in [0, 1] in the .mat file; scaled to uint8).
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    for out_dir in (pred_dir, color_dir):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    # Number of images = non-empty lines of the data list (robust to a
    # missing or extra trailing newline).
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    input_h, input_w = input_size
    stride_h, stride_w = strides
    # Index of the scale-1 batch; its size is the final prediction size.
    scale_ind = 2 if args.scale_aug else 0

    try:
        # Iterate over testing steps.
        for step in range(num_steps):
            rescale_img_batches = sess.run(rescale_image_batches)

            # Final segmentation results (summed across scales / flips).
            final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
            final_lab_size[-1] = args.num_classes
            final_lab_batch = np.zeros(final_lab_size)

            # Iterate over multiple scales.
            for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
                img_size = img_batch.shape
                img_h, img_w = img_size[1:3]

                # Zero-pad so the image is at least as large as the input.
                padimg_h = max(img_h, input_h)
                padimg_w = max(img_w, input_w)
                padimg_batch = np.zeros(
                    (img_size[0], padimg_h, padimg_w, img_size[3]),
                    dtype=np.float32)
                padimg_batch[:, :img_h, :img_w, :] = img_batch

                # Accumulators: summed probabilities and patch coverage.
                # Probabilities must start at zero (not ignore_label) since
                # they are averaged below.
                lab_batch = np.zeros(
                    (img_size[0], padimg_h, padimg_w, args.num_classes),
                    dtype=np.float32)
                num_batch = np.zeros(
                    (img_size[0], padimg_h, padimg_w), dtype=np.float32)

                npatches_h = math.ceil(
                    1.0 * (padimg_h - input_h) / stride_h) + 1
                npatches_w = math.ceil(
                    1.0 * (padimg_w - input_w) / stride_w) + 1

                # Create the ending index of each patch.
                patch_indh = np.linspace(
                    input_h, padimg_h, npatches_h, dtype=np.int32)
                patch_indw = np.linspace(
                    input_w, padimg_w, npatches_w, dtype=np.int32)

                for indh in patch_indh:
                    for indw in patch_indw:
                        sh, eh = indh - input_h, indh  # start&end ind of H
                        sw, ew = indw - input_w, indw  # start&end ind of W
                        cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                        out = sess.run(
                            output,
                            feed_dict={crop_image_batch: cropimg_batch})
                        lab_batch[:, sh:eh, sw:ew, :] += out
                        num_batch[:, sh:eh, sw:ew] += 1

                # Average overlapping patches; pixels no patch covered (only
                # possible when stride > input size) stay zero instead of
                # dividing by zero.
                lab_batch /= np.maximum(num_batch, 1)[..., np.newaxis]
                lab_batch = lab_batch[0, :img_h, :img_w, :]

                # Rescale prediction back to the scale-1 resolution.
                lab_batch = cv2.resize(
                    lab_batch,
                    (final_lab_size[1], final_lab_size[0]),
                    interpolation=cv2.INTER_LINEAR)

                if is_flip:
                    # Flip prediction back to original orientation.
                    lab_batch = lab_batch[:, ::-1, :]

                final_lab_batch += lab_batch

            final_lab_ind = np.argmax(final_lab_batch, axis=-1)
            final_lab_ind = final_lab_ind.astype(np.uint8)

            # Always save predictions as lossless .png, whatever the input
            # image extension was.
            stem = os.path.splitext(os.path.basename(image_list[step]))[0]
            basename = stem + '.png'

            predname = os.path.join(pred_dir, basename)
            Image.fromarray(final_lab_ind, mode='L').save(predname)

            colorname = os.path.join(color_dir, basename)
            color = colormap[final_lab_ind]
            Image.fromarray(color, mode='RGB').save(colorname)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and start the Inference process.

    Extracts semantic prototypes from every image in ``args.data_list``:
    1. Embeddings are computed with sliding-window inference (patches of
       ``args.input_size`` with ``args.strides``).
    2. Pixels are clustered with k-means on [embedding, location, rgb].
    3. A prototype feature and majority label are kept per cluster.
    All prototypes are saved to ``<save_dir>/prototypes/prototype_features.npy``
    and ``<save_dir>/prototypes/prototype_labels.npy``.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image
        label = reader.label
        image_batch = tf.expand_dims(image, axis=0)
        label_batch = tf.expand_dims(label, axis=0)

    # Create input tensor to the Network
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network and output prediction.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions: per-crop embedding resized to the crop size.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]])

    # Full-sized (unpadded) per-image inputs, fed after sliding-window
    # inference. All four must share the same H x W so that flattened pixel
    # indices line up.
    label_input = tf.placeholder(tf.int32, shape=[1, None, None, 1])
    embedding_input = tf.placeholder(
        tf.float32, shape=[1, None, None, args.embedding_dim])
    embedding = common_utils.normalize_embedding(embedding_input)
    loc_feature = tf.placeholder(tf.float32, shape=[1, None, None, 2])
    rgb_feature = tf.placeholder(tf.float32, shape=[1, None, None, 3])

    # Combine embedding with location features and kmeans.
    shape = tf.shape(embedding)
    cluster_labels = common_utils.initialize_cluster_labels(
        [args.num_clusters, args.num_clusters], [shape[1], shape[2]])
    embedding = tf.reshape(embedding, [-1, args.embedding_dim])
    labels = tf.reshape(label_input, [-1])
    cluster_labels = tf.reshape(cluster_labels, [-1])
    location_features = tf.reshape(loc_feature, [-1, 2])
    rgb_features = common_utils.normalize_embedding(
        tf.reshape(rgb_feature, [-1, 3])) / args.embedding_dim

    # Collect pixels of valid semantic classes.
    valid_pixels = tf.where(tf.not_equal(labels, args.ignore_label))
    labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1)
    cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1)
    embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
    location_features = tf.squeeze(
        tf.gather(location_features, valid_pixels), axis=1)
    rgb_features = tf.squeeze(tf.gather(rgb_features, valid_pixels), axis=1)

    # Generate cluster labels via kmeans clustering.
    embedding_with_location = tf.concat(
        [embedding, location_features, rgb_features], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_labels = common_utils.kmeans_with_initial_labels(
        embedding_with_location,
        cluster_labels,
        args.num_clusters * args.num_clusters,
        args.kmeans_iterations)
    _, cluster_labels = tf.unique(cluster_labels)

    # Find pixels of majority semantic classes.
    select_pixels, prototype_labels = eval_utils.find_majority_label_index(
        labels, cluster_labels)

    # Calculate the prototype features.
    cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels), axis=1)
    embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)

    prototype_features = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Create directory for saving prototypes.
    save_dir = os.path.join(args.save_dir, 'prototypes')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Number of images = non-empty lines of the data list.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    stride_h, stride_w = strides
    # Collect per-image results and concatenate once (avoids O(n^2) copies).
    feature_chunks = []
    label_chunks = []

    try:
        # Iterate over testing steps.
        for step in tqdm(range(num_steps)):
            image_batch_np, label_batch_np = sess.run(
                [image_batch, label_batch])

            img_size = image_batch_np.shape
            img_h, img_w = img_size[1:3]

            # Zero-pad so the image is at least as large as the input.
            padded_h = max(img_h, input_size[0])
            padded_w = max(img_w, input_size[1])
            padded_img_batch = np.zeros(
                (img_size[0], padded_h, padded_w, img_size[3]),
                dtype=np.float32)
            padded_img_batch[:, :img_h, :img_w, :] = image_batch_np

            npatches_h = math.ceil(
                1.0 * (padded_h - input_size[0]) / stride_h) + 1
            npatches_w = math.ceil(
                1.0 * (padded_w - input_size[1]) / stride_w) + 1

            # Create the ending index of each patch.
            patch_indh = np.linspace(
                input_size[0], padded_h, npatches_h, dtype=np.int32)
            patch_indw = np.linspace(
                input_size[1], padded_w, npatches_w, dtype=np.int32)

            # Create embedding holder. Overlapping patches are summed, not
            # averaged; the graph feeds this through
            # common_utils.normalize_embedding.
            embedding_all_np = np.zeros(
                (img_size[0], padded_h, padded_w, args.embedding_dim),
                dtype=np.float32)

            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_size[0], indh  # start & end ind of H
                    sw, ew = indw - input_size[1], indw  # start & end ind of W
                    cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :]
                    embedding_np = sess.run(
                        output, feed_dict={crop_image_batch: cropimg_batch})
                    embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np

            # Crop everything back to the original image size so the
            # flattened embedding, label, location and rgb arrays describe
            # the same pixels. Previously, location/rgb features were left
            # padded and became misaligned whenever padding was applied.
            embedding_all_np = embedding_all_np[:, :img_h, :img_w, :]
            loc_feature_np = np.asarray(
                common_utils.generate_location_features_np(
                    [padded_h, padded_w]))
            loc_feature_np = loc_feature_np[:, :img_h, :img_w, :]
            rgb_feature_np = padded_img_batch[:, :img_h, :img_w, :]

            feed_dict = {label_input: label_batch_np,
                         embedding_input: embedding_all_np,
                         loc_feature: loc_feature_np,
                         rgb_feature: rgb_feature_np}

            batch_features_np, batch_labels_np = sess.run(
                [prototype_features, prototype_labels], feed_dict=feed_dict)
            feature_chunks.append(batch_features_np)
            label_chunks.append(batch_labels_np)
    finally:
        coord.request_stop()
        coord.join(threads)

    if not label_chunks:
        print('No prototypes extracted: the data list is empty.')
        return

    prototype_features_np = np.concatenate(feature_chunks, axis=0)
    prototype_labels_np = np.concatenate(label_chunks, axis=0)

    print('Total number of prototypes extracted: ', len(prototype_labels_np))

    # Close the file handles so the .npy files are fully flushed.
    with tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                       mode='wb') as fout:
        np.save(fout, prototype_features_np)
    with tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                       mode='wb') as fout:
        np.save(fout, prototype_labels_np)
def main():
    """Create the model and start the Inference process.

    For every image in ``args.data_list``:
    1. Compute embeddings and cluster them with k-means on
       [embedding, location].
    2. Assign each cluster a semantic label by k-nearest-neighbour voting
       against the prototypes loaded from ``args.prototype_dir``.
    3. Write gray/colour predictions, colour-coded clusters, a PCA view of
       the embedding and one cropped patch per cluster under
       ``args.save_dir``, plus an ``index.html`` linking them.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image_list = reader.image_list
        image = reader.image
        cluster_label = reader.cluster_label
        loc_feature = reader.loc_feature
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(tf.expand_dims(image, axis=0),
                    args.embedding_dim,
                    False,
                    True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = list(tf.global_variables())

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image)[:2, ])
    embedding = common_utils.normalize_embedding(output)
    embedding = tf.squeeze(embedding, axis=0)

    # Crop everything to the valid (height, width) region.
    image = image[:height, :width]
    embedding = tf.reshape(
        embedding[:height, :width], [-1, args.embedding_dim])
    cluster_label = tf.reshape(cluster_label[:height, :width], [-1])
    loc_feature = tf.reshape(loc_feature[:height, :width], [-1, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(
        tf.float32, shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans.
    embedding_with_location = tf.concat([embedding, loc_feature], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_label = common_utils.kmeans_with_initial_labels(
        embedding_with_location,
        cluster_label,
        args.num_clusters * args.num_clusters,
        args.kmeans_iterations)
    # Relabel clusters to contiguous ids 0..K-1.
    _, cluster_labels = tf.unique(cluster_label)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)
    cluster_labels = tf.reshape(cluster_labels, [height, width])

    # Predict semantic labels by kNN over prototype similarities.
    similarities = tf.matmul(test_prototypes, prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(
        similarities, k=args.k_in_nearest_neighbors, sorted=True)

    prototype_semantic_predictions = eval_utils.k_nearest_neighbors(
        k_predictions, prototype_labels)
    semantic_predictions = tf.gather(prototype_semantic_predictions,
                                     cluster_labels)

    # Visualize embedding using PCA, min-max rescaled to [0, 255].
    embedding = vis_utils.pca(
        tf.reshape(embedding, [1, height, width, args.embedding_dim]))
    embedding = ((embedding - tf.reduce_min(embedding))
                 / (tf.reduce_max(embedding) - tf.reduce_min(embedding)))
    embedding = tf.cast(embedding * 255, tf.uint8)
    embedding = tf.squeeze(embedding, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap (stored in [0, 1] in the .mat file; scaled to uint8).
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    cluster_dir = os.path.join(args.save_dir, 'cluster')
    embedding_dir = os.path.join(args.save_dir, 'embedding')
    patch_dir = os.path.join(args.save_dir, 'test_patches')
    for out_dir in (pred_dir, color_dir, cluster_dir, embedding_dir,
                    patch_dir):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    # Number of images = non-empty lines of the data list.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    # Load prototype features and labels
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    f = html_helper.open_html_for_write(
        os.path.join(args.save_dir, 'index.html'),
        'Visualization for Segment Collaging')
    try:
        # Iterate over testing steps.
        for step in range(num_steps):
            (image_np, semantic_predictions_np, cluster_labels_np,
             embedding_np, k_predictions_np) = sess.run(
                 [image, semantic_predictions, cluster_labels, embedding,
                  k_predictions],
                 feed_dict=feed_dict)

            imgname = os.path.basename(image_list[step])
            # Always save outputs as lossless .png.
            basename = os.path.splitext(imgname)[0] + '.png'

            # mode='L' needs 8-bit data; cast in case the predictions come
            # back as a wider integer dtype.
            predname = os.path.join(pred_dir, basename)
            Image.fromarray(semantic_predictions_np.astype(np.uint8),
                            mode='L').save(predname)

            colorname = os.path.join(color_dir, basename)
            color = colormap[semantic_predictions_np]
            Image.fromarray(color, mode='RGB').save(colorname)

            clustername = os.path.join(cluster_dir, basename)
            cluster = colormap[cluster_labels_np]
            Image.fromarray(cluster, mode='RGB').save(clustername)

            embeddingname = os.path.join(embedding_dir, basename)
            Image.fromarray(embedding_np, mode='RGB').save(embeddingname)

            # Save one patch per cluster: the cluster's bounding box with
            # all pixels outside the cluster blacked out.
            image_np = (image_np + IMG_MEAN).astype(np.uint8)
            for i in range(np.max(cluster_labels_np) + 1):
                in_cluster = cluster_labels_np == i
                rows, cols = np.where(in_cluster)
                if rows.size == 0:
                    continue
                image_temp = image_np.copy()
                image_temp[~in_cluster] = 0
                # +1: slice ends are exclusive, so the last row/column of
                # the cluster must be included explicitly.
                crop = image_temp[rows.min():rows.max() + 1,
                                  cols.min():cols.max() + 1]
                # scipy.misc.imsave was removed from SciPy; PIL writes the
                # same uint8 data unchanged.
                Image.fromarray(crop).save(
                    patch_dir + '/' + basename + str(i).zfill(3) + '.png')

            html_helper.write_vmf_to_html(
                f,
                './images/' + imgname,
                './labels/' + basename,
                './color/' + basename,
                './cluster/' + basename,
                './embedding/' + basename,
                './test_patches/' + basename,
                './patches/',
                k_predictions_np)

            if (step + 1) % 100 == 0:
                print('Processed batches: ', (step + 1), '/', num_steps)
    finally:
        html_helper.close_html(f)
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and start the Inference process.

    Runs the embedding network on each full image in ``args.data_list``,
    extracts semantic prototypes with
    ``eval_utils.extract_trained_prototypes`` and saves them to
    ``<save_dir>/prototypes/prototype_features.npy`` and
    ``<save_dir>/prototypes/prototype_labels.npy``.
    """
    args = get_arguments()

    # Parse image processing arguments (validated but not otherwise used;
    # this script processes whole images).
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image_batch = tf.expand_dims(reader.image, axis=0)
        label_batch = tf.expand_dims(reader.label, axis=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3, ])
    embedding = common_utils.normalize_embedding(output)
    batch_size = embedding.get_shape().as_list()[0]

    # Only batch element 0's initial clusters / location features are used;
    # image_batch has batch size 1 (created with expand_dims above).
    labels = label_batch
    initial_cluster_labels = cluster_label_batch[0, :, :]
    location_features = tf.reshape(loc_feature_batch[0, :, :], [-1, 2])

    prototype_feature_list = []
    prototype_label_list = []
    for bs in range(batch_size):
        cur_labels = tf.reshape(labels[bs], [-1])
        cur_cluster_labels = tf.reshape(initial_cluster_labels, [-1])
        cur_embedding = tf.reshape(embedding[bs], [-1, args.embedding_dim])

        (prototype_features,
         prototype_labels,
         _) = eval_utils.extract_trained_prototypes(
             cur_embedding, location_features, cur_cluster_labels,
             args.num_clusters * args.num_clusters,
             args.kmeans_iterations, cur_labels,
             1, args.ignore_label,
             'semantic')

        prototype_feature_list.append(prototype_features)
        prototype_label_list.append(prototype_labels)

    prototype_features = tf.concat(prototype_feature_list, axis=0)
    prototype_labels = tf.concat(prototype_label_list, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Create directory for saving prototypes.
    save_dir = os.path.join(args.save_dir, 'prototypes')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Number of images = non-empty lines of the data list.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    # Collect per-image results and concatenate once (avoids O(n^2) copies).
    feature_chunks = []
    label_chunks = []
    try:
        # Iterate over testing steps.
        for step in range(num_steps):
            batch_features_np, batch_labels_np = sess.run(
                [prototype_features, prototype_labels])
            feature_chunks.append(batch_features_np)
            label_chunks.append(batch_labels_np)

            if (step + 1) % 100 == 0:
                print('Processed batches: ', (step + 1), '/', num_steps)
    finally:
        coord.request_stop()
        coord.join(threads)

    if not label_chunks:
        print('No prototypes extracted: the data list is empty.')
        return

    prototype_features_np = np.concatenate(feature_chunks, axis=0)
    prototype_labels_np = np.concatenate(label_chunks, axis=0)

    print('Total number of prototypes extracted: ', len(prototype_labels_np))

    # Close the file handles so the .npy files are fully flushed.
    with tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                       mode='wb') as fout:
        np.save(fout, prototype_features_np)
    with tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                       mode='wb') as fout:
        np.save(fout, prototype_labels_np)
def main():
    """Create the model and start training.

    Trains the embedding network with the unsupervised SegSort loss (guided
    by the reader's precomputed cluster labels) plus L2 weight decay.
    Summaries are written every ``args.update_tb_every`` steps; checkpoints
    are written every ``args.save_pred_every`` steps and after the final
    step.
    """
    # Read CL arguments and snapshot the arguments into text file.
    args = get_arguments()
    utils.general.snapshot_arg(args)

    # The segmentation network is stride 8 by default.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Initialize the random seed.
    tf.set_random_seed(args.random_seed)

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Current step; fed every iteration to drive the LR schedule.
    step_ph = tf.placeholder(dtype=tf.float32, shape=())

    # Load the data reader.
    with tf.device('/cpu:0'):
        with tf.name_scope('create_inputs'):
            reader = SegSortUnsupImageReader(
                args.data_dir,
                args.data_list,
                input_size,
                args.random_scale,
                args.random_mirror,
                args.random_crop,
                args.ignore_label,
                IMG_MEAN)

            image_batch, _, cluster_label_batch = (
                reader.dequeue(args.batch_size))

    # Shrink labels to the size of the network output.
    cluster_labels = tf.image.resize_nearest_neighbor(
        cluster_label_batch, innet_size)

    # Create network and predictions.
    with tf.device('/gpu:1'):
        outputs = model(image_batch,
                        args.embedding_dim,
                        args.is_training,
                        args.use_global_status)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables()
        if 'block5' not in v.name or not args.not_restore_classifier
    ]

    with tf.device('/gpu:{:d}'.format(args.num_gpu - 1)):
        # Add Unsupervised SegSort loss.
        seg_losses = train_utils.add_unsupervised_segsort_loss(
            outputs[0],
            args.concentration,
            cluster_labels,
        )

        # Define weight regularization loss.
        weight_decay = args.weight_decay
        l2_losses = [weight_decay * tf.nn.l2_loss(v)
                     for v in tf.trainable_variables() if 'weights' in v.name]

    # Sum all loss terms.
    mean_seg_loss = seg_losses
    mean_l2_loss = tf.add_n(l2_losses)
    reduced_loss = mean_seg_loss + mean_l2_loss

    # Grab variable names which are used for training.
    all_trainable = tf.trainable_variables()
    fc_trainable = [v for v in all_trainable if 'block5' in v.name]  # lr*10
    base_trainable = [v for v in all_trainable if 'block5' not in v.name]  # lr*1

    # Computes gradients per iteration.
    grads = tf.gradients(reduced_loss, base_trainable + fc_trainable)
    grads_base = grads[:len(base_trainable)]
    grads_fc = grads[len(base_trainable):]

    # Define optimisation parameters (polynomial LR decay).
    base_lr = tf.constant(args.learning_rate)
    # Decay horizon: historically fixed at 100000 steps. It is extended to
    # num_steps for longer runs, because once step > horizon the base
    # (1 - step / horizon) goes negative, and a non-integer args.power then
    # makes the learning rate NaN.
    pow_till = max(args.num_steps, 100000)
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / pow_till), args.power))
    opt_base = tf.train.MomentumOptimizer(learning_rate * 1.0, args.momentum)
    opt_fc = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)

    # Define tensorflow operations which apply gradients to update variables.
    train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable))
    train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
    train_op = tf.group(train_op_base, train_op_fc)

    # Process for visualisation.
    with tf.device('/cpu:0'):
        # Image summary for input image, cluster label and prediction.
        output_vis = tf.image.resize_nearest_neighbor(
            outputs[-1], tf.shape(image_batch)[1:3, ])
        output_vis = tf.argmax(output_vis, axis=3)
        output_vis = tf.expand_dims(output_vis, axis=3)
        output_vis = tf.cast(output_vis, dtype=tf.uint8)

        labels_vis = tf.cast(cluster_label_batch, dtype=tf.uint8)

        in_summary = tf.py_func(
            utils.general.inv_preprocess,
            [image_batch, IMG_MEAN],
            tf.uint8)
        gt_summary = tf.py_func(
            utils.general.decode_labels,
            [labels_vis, args.num_classes],
            tf.uint8)
        out_summary = tf.py_func(
            utils.general.decode_labels,
            [output_vis, args.num_classes],
            tf.uint8)

        # Concatenate image summaries in a row.
        tf.summary.image(
            'images',
            tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
            max_outputs=args.batch_size)

        # Scalar summary for different loss terms.
        tf.summary.scalar('seg_loss', mean_seg_loss)

        total_summary = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(
            args.snapshot_dir, graph=tf.get_default_graph())

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The graph pins ops to '/gpu:1' and '/gpu:<num_gpu-1>'; let TF fall back
    # to an available device on machines without that GPU.
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Load variables if the checkpoint is provided.
    if args.restore_from is not None and len(args.restore_from) > 0:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        # Iterate over training steps.
        pbar = tqdm(range(args.num_steps))
        for step in pbar:
            feed_dict = {step_ph: step}

            # Each sub-iteration runs train_op (no gradient accumulation);
            # only the reported loss is averaged.
            step_loss = 0
            for it in range(args.iter_size):
                # Update summary periodically.
                if (it == args.iter_size - 1
                        and step % args.update_tb_every == 0):
                    loss_value, summary, _ = sess.run(
                        [reduced_loss, total_summary, train_op],
                        feed_dict=feed_dict)
                    summary_writer.add_summary(summary, step)
                else:
                    loss_value, _ = sess.run(
                        [reduced_loss, train_op], feed_dict=feed_dict)
                step_loss += loss_value
            step_loss /= args.iter_size

            lr = sess.run(learning_rate, feed_dict=feed_dict)

            # Save periodically, and always after the last step.
            is_last_step = step == args.num_steps - 1
            if (step % args.save_pred_every == 0 and step > 0) or is_last_step:
                save(saver, sess, args.snapshot_dir, step)

            pbar.set_description(
                'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr))
    finally:
        # Always shut down the input queue threads and flush summaries.
        summary_writer.close()
        coord.request_stop()
        coord.join(threads)
def main():
    """Create the model and run sliding-window inference on 3D volumes.

    For each case id in ``range(70)``:
      1. load ``args.data + '<id>.nii'`` and zero-pad its first two axes to
         ``input_size``;
      2. load the liver mask ``args.liver_path + '<id>-ori.nii'``, merge
         label 2 into label 1 and dilate it once;
      3. slide a window of ``img_cols`` slices along the third axis, limited
         to the mask's extent (with a small margin), run the network on each
         window and average the per-voxel class scores (the first and last
         slice of every window are discarded);
      4. threshold class 1 (liver) and class 2 (tumor) scores, keep the
         largest predicted liver component, keep tumor only inside the
         largest component of the input liver mask, fill holes, and save
         the label volume to ``args.save_path + 'test-segmentation-<id>.nii'``.
    """
    args = get_arguments()
    # TODO: 5. post-processing and save
    # Parse image processing arguments.
    print('get model!')
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)

    # Create input tensor to the Network. The batch dimension is fixed to 8;
    # presumably this must equal img_cols below (slices per window) -- TODO
    # confirm against tans2d.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[8, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network and output prediction.
    outputs = model(crop_image_batch, args.num_classes, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name]

    # Output predictions: upsample the last branch to the input size and
    # turn logits into per-class probabilities.
    output = outputs[-1]
    output = tf.image.resize_bilinear(
        output, tf.shape(crop_image_batch)[1:3, ])
    output = tf.nn.softmax(output, dim=3)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        loadckpt(loader, sess, args.restore_from)

    def _keep_largest_component(binary):
        """Return a 0/1 array marking the largest connected component of
        *binary* as labelled by ``measure.label``; all zeros if there is none
        (previously ``max([])`` raised ValueError on an empty prediction)."""
        labels, num = measure.label(binary, return_num=True)
        if num == 0:
            return np.zeros_like(labels)
        regions = measure.regionprops(labels)
        # max() returns the first maximal region, matching list.index(max()).
        largest = max(regions, key=lambda r: r.area).label
        return (labels == largest).astype(labels.dtype)

    # Sliding-window parameters along the third axis.
    batch = 1
    img_deps = input_size[0]
    img_rows = input_size[1]
    img_cols = 8
    # Integer division so range() accepts the step on Python 3 as well.
    window_cols = img_cols // 4
    count = 0
    pixel_mean = np.array((122.675, 122.669, 122.008), dtype=np.float32)

    for case_id in range(70):
        print('-' * 30)
        print('preprocessing test data...' + str(case_id))
        print('-' * 30)
        # 1. Load a medpy volume and zero-pad the first two axes.
        imgs_test, img_test_header = load(args.data + str(case_id) + '.nii')
        orig_x, orig_y = imgs_test.shape[0], imgs_test.shape[1]
        mm = np.zeros((input_size[0], input_size[1], imgs_test.shape[2]))
        mm[:orig_x, :orig_y, :imgs_test.shape[2]] = imgs_test
        imgs_test = mm

        # Load the liver mask, merge label 2 into 1 and dilate once.
        mask, _ = load(args.liver_path + str(case_id) + '-ori.nii')
        mask[mask == 2] = 1
        mask = ndimage.binary_dilation(mask, iterations=1).astype(mask.dtype)
        print('-' * 30)
        print('Predicting masks on test data...' + str(case_id))
        print('-' * 30)

        # Bounding extent of the mask, used to limit the slice range.
        index = np.where(mask == 1)
        mini = np.min(index, axis=-1)
        maxi = np.max(index, axis=-1)

        box_test = np.zeros((batch, img_deps, img_rows, img_cols, 1),
                            dtype="float32")
        x = imgs_test.shape[0]
        y = imgs_test.shape[1]
        z = imgs_test.shape[2]
        right_cols = int(min(z, maxi[2] + 10) - img_cols)
        left_cols = max(0, min(mini[2] - 5, right_cols))
        score = np.zeros((x, y, z, 3), dtype='float32')
        score_num = np.zeros((x, y, z, 3), dtype='int16')
        for cols in range(left_cols, right_cols + window_cols, window_cols):
            # A window that would run past the last slice is clamped so that
            # it ends exactly at the last slice.
            start = min(cols, z - img_cols)
            patch_test = imgs_test[0:img_deps, 0:img_rows,
                                   start:start + img_cols]
            box_test[count, :, :, :, 0] = patch_test
            incol = box_test.shape[3]
            box_testt = tans2d(box_test, incol)
            box_testt = (box_testt + 250) * 255 / 500
            box_testt -= pixel_mean
            feed_dict = {crop_image_batch: box_testt}
            patch_test_mask = sess.run(output, feed_dict=feed_dict)
            patch_test_mask = trans3d(patch_test_mask, incol)
            # Discard the first and last slice of every window.
            patch_test_mask = patch_test_mask[:, :, :, 1:-1, :]
            for i in range(batch):
                score[0:img_deps, 0:img_rows,
                      start + 1:start + img_cols - 1, :] += patch_test_mask[i]
                score_num[0:img_deps, 0:img_rows,
                          start + 1:start + img_cols - 1, :] += 1

        # Average overlapping windows, then crop the padding off again
        # (was hard-coded to 512x512).
        score = score / (score_num + 1e-4)
        result1 = score[:orig_x, :orig_y, :, 1]
        result2 = score[:orig_x, :orig_y, :, 2]
        result1[result1 >= args.thres_liver] = 1
        result1[result1 < args.thres_liver] = 0
        result2[result2 >= args.thres_tumor] = 1
        result2[result2 < args.thres_tumor] = 0
        # Tumor voxels also count as liver.
        result1[result2 == 1] = 1
        print('-' * 30)
        print('Postprocessing on mask ...' + str(case_id))
        print('-' * 30)

        tumor_mask = result2
        # Preserve the largest predicted liver component.
        liver_res = _keep_largest_component(result1)
        # Largest component of the (dilated again) input liver mask.
        mask = ndimage.binary_dilation(mask, iterations=1).astype(mask.dtype)
        liver_labels = _keep_largest_component(mask)
        liver_labels = ndimage.binary_fill_holes(liver_labels).astype(int)
        # Preserve tumor within that largest liver component only.
        tumor_mask = tumor_mask * liver_labels
        tumor_mask = ndimage.binary_fill_holes(tumor_mask).astype(int)
        tumor_mask = np.array(tumor_mask, dtype='uint8')
        liver_res = np.array(liver_res, dtype='uint8')
        liver_res = ndimage.binary_fill_holes(liver_res).astype(int)
        liver_res[tumor_mask == 1] = 2
        liver_res = np.array(liver_res, dtype='uint8')
        save(liver_res,
             args.save_path + 'test-segmentation-' + str(case_id) + '.nii',
             img_test_header)
def main():
    """Create the model and run sliding-window inference on an image list.

    Each image produced by the reader is zero-padded to at least
    ``input_size``, split into overlapping ``input_size`` crops whose end
    indices are spaced roughly ``args.strides`` apart, and the softmax
    outputs of all crops are summed per pixel. The per-pixel argmax is saved
    as a grayscale PNG under ``<save_dir>/gray`` and as a colour-mapped PNG
    under ``<save_dir>/color``.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    input_h, input_w = input_size
    stride_h, stride_w = strides

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image
        image_list = reader.image_list
        image_batch = tf.expand_dims(image, dim=0)

    # Create input tensor to the Network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network and output prediction.
    outputs = model(crop_image_batch, args.num_classes, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[-1]
    output = tf.image.resize_bilinear(output, tf.shape(crop_image_batch)[1:3, ])
    output = tf.nn.softmax(output, axis=3)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Count listed images, ignoring blank lines, so a missing or extra
    # trailing newline does not change the number of steps.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    for step in range(num_steps):
        img_batch = sess.run(image_batch)
        img_h, img_w = img_batch.shape[1:3]

        # Zero-pad so the image is at least as large as the network input.
        padimg_h = max(img_h, input_h)
        padimg_w = max(img_w, input_w)
        padimg_size = list(img_batch.shape)
        padimg_size[1] = padimg_h
        padimg_size[2] = padimg_w
        padimg_batch = np.zeros(padimg_size, dtype=np.float32)
        padimg_batch[:, :img_h, :img_w, :] = img_batch

        # Per-pixel sum of class probabilities over all crops. Starts at
        # zero: a constant offset would not change the argmax but would
        # waste float32 precision.
        lab_size = list(padimg_size)
        lab_size[-1] = args.num_classes
        lab_batch = np.zeros(lab_size, dtype=np.float32)

        # int(): math.ceil returns a float on Python 2, and np.linspace
        # requires an integer sample count.
        npatches_h = int(math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
        npatches_w = int(math.ceil(1.0 * (padimg_w - input_w) / stride_w)) + 1
        # Create the ending index of each patch.
        patch_indh = np.linspace(input_h, padimg_h, npatches_h, dtype=np.int32)
        patch_indw = np.linspace(input_w, padimg_w, npatches_w, dtype=np.int32)

        for indh in patch_indh:
            for indw in patch_indw:
                sh, eh = indh - input_h, indh  # start & end ind of H
                sw, ew = indw - input_w, indw  # start & end ind of W
                cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                feed_dict = {crop_image_batch: cropimg_batch}
                out = sess.run(output, feed_dict=feed_dict)
                lab_batch[:, sh:eh, sw:ew, :] += out

        lab_batch = lab_batch[0, :img_h, :img_w, :]
        lab_batch = np.argmax(lab_batch, axis=-1)
        lab_batch = lab_batch.astype(np.uint8)

        # Always write PNG: swap only the real extension (replace('jpg', ...)
        # also hit 'jpg' elsewhere in the name and mangled '.jpeg').
        basename = os.path.basename(image_list[step])
        basename = os.path.splitext(basename)[0] + '.png'
        predname = os.path.join(pred_dir, basename)
        Image.fromarray(lab_batch, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[lab_batch]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
def main():
    """Create the model and run multi-scale KNN-prototype inference.

    Each image is optionally rescaled (``args.scale_aug``) and horizontally
    flipped (``args.flip_aug``). Every version is padded and processed in
    ``input_size`` crops. For each crop, the embedding is clustered with
    k-means, and each cluster prototype is matched to its
    ``args.k_in_nearest_neighbors`` most similar stored prototypes, whose
    labels are counted as per-pixel votes. Votes from all versions are
    resized to the scale-1 resolution (flips undone) and summed. The argmax
    is saved as grayscale and colour PNGs.
    """
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    input_h, input_w = input_size
    stride_h, stride_w = strides

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image[:reader.height, :reader.width]
        image_list = reader.image_list
        image_batch = tf.expand_dims(image, dim=0)

    # Create multi-scale augmented datas.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented datas (one per scale).
    if args.flip_aug:
        for img in rescale_image_batches[:len(scales)]:
            flip_img = tf.image.flip_left_right(tf.squeeze(img, axis=0))
            rescale_image_batches.append(tf.expand_dims(flip_img, axis=0))
            is_flipped.append(True)

    # Create input tensor to the Network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans
    shape = embedding.get_shape().as_list()
    loc_feature = tf.expand_dims(
        common_utils.generate_location_features([shape[1], shape[2]], 'float'),
        0)
    embedding_with_location = tf.concat([embedding, loc_feature], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Perform Kmeans clustering and extract prototypes.
    cluster_labels = common_utils.kmeans(
        embedding_with_location,
        [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels: labels of the k most similar stored
    # prototypes, broadcast back to every pixel of each cluster.
    similarities = tf.matmul(test_prototypes, prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)
    k_predictions = tf.gather(prototype_labels, k_predictions)
    k_predictions = tf.gather(k_predictions, cluster_labels)
    k_predictions = tf.reshape(
        k_predictions,
        [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors])

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Count listed images, ignoring blank lines, so a missing or extra
    # trailing newline does not change the number of steps.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    pbar = tqdm(range(num_steps))
    for step in pbar:
        rescale_img_batches = sess.run(rescale_image_batches)
        # Final segmentation results (summed across scales), kept at the
        # resolution of the scale-1 image.
        scale_ind = 2 if args.scale_aug else 0
        final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
        final_lab_size[-1] = args.num_classes
        final_lab_batch = np.zeros(final_lab_size)

        # Iterate over multiple scales.
        for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
            img_h, img_w = img_batch.shape[1:3]

            # Zero-pad so the image is at least as large as the input.
            padimg_h = max(img_h, input_h)
            padimg_w = max(img_w, input_w)
            padimg_size = list(img_batch.shape)
            padimg_size[1] = padimg_h
            padimg_size[2] = padimg_w
            padimg_batch = np.zeros(padimg_size, dtype=np.float32)
            padimg_batch[:, :img_h, :img_w, :] = img_batch

            # int(): math.ceil returns a float on Python 2, and np.linspace
            # requires an integer sample count.
            npatches_h = int(
                math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
            npatches_w = int(
                math.ceil(1.0 * (padimg_w - input_w) / stride_w)) + 1

            # Create padded prediction array (KNN vote counts per class).
            pred_size = list(padimg_size)
            pred_size[-1] = args.num_classes
            predictions_np = np.zeros(pred_size, dtype=np.int32)

            # Create the ending index of each patch.
            patch_indh = np.linspace(
                input_h, padimg_h, npatches_h, dtype=np.int32)
            patch_indw = np.linspace(
                input_w, padimg_w, npatches_w, dtype=np.int32)

            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start & end ind of H
                    sw, ew = indw - input_w, indw  # start & end ind of W
                    cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                    feed_dict[crop_image_batch] = cropimg_batch
                    k_predictions_np = sess.run(k_predictions,
                                                feed_dict=feed_dict)

                    # Sum up KNN votes.
                    # This is the speed bottleneck for multiscale inference.
                    # Use singlescale inference for fast results.
                    # TODO: Either compute on GPU or change a way of
                    # implementation.
                    # Count votes directly in int32 (np.int was removed in
                    # NumPy 1.24).
                    for c in range(args.num_classes):
                        predictions_np[:, sh:eh, sw:ew, c] += np.sum(
                            k_predictions_np == c, axis=3, dtype=np.int32)

            predictions_np = predictions_np[0, :img_h, :img_w, :]
            lab_batch = predictions_np.astype(np.float32)
            # Rescale prediction back to original resolution.
            lab_batch = cv2.resize(lab_batch,
                                   (final_lab_size[1], final_lab_size[0]),
                                   interpolation=cv2.INTER_LINEAR)
            if is_flip:
                # Flipped prediction back to original orientation.
                lab_batch = lab_batch[:, ::-1, :]
            final_lab_batch += lab_batch

        final_lab_ind = np.argmax(final_lab_batch, axis=-1)
        final_lab_ind = final_lab_ind.astype(np.uint8)

        # Always write PNG: swap only the real extension.
        basename = os.path.basename(image_list[step])
        basename = os.path.splitext(basename)[0] + '.png'
        predname = os.path.join(pred_dir, basename)
        Image.fromarray(final_lab_ind, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[final_lab_ind]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
def main():
    """Create the model and run single-scale KNN-prototype inference.

    The whole (reader-sized) image is embedded, and the embedding plus
    location features are clustered with k-means. Cluster prototypes are
    labelled from the stored training prototypes via
    ``eval_utils.predict_semantic_instance_labels``. The prediction, cropped
    to the reader's height/width, is saved as grayscale and colour PNGs.
    """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            parse_commastr(args.input_size),
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, dim=0)
        label_batch = tf.expand_dims(reader.label, dim=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0)
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3, ])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features.
    embedding_with_location = tf.concat([embedding, loc_feature_batch], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Kmeans clustering.
    cluster_labels = common_utils.kmeans(
        embedding_with_location,
        [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
    semantic_predictions, _ = eval_utils.predict_semantic_instance_labels(
        cluster_labels,
        test_prototypes,
        prototype_features,
        prototype_labels,
        None,
        args.k_in_nearest_neighbors)
    semantic_predictions = tf.cast(semantic_predictions, tf.uint8)
    semantic_predictions = tf.squeeze(semantic_predictions)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Count listed images, ignoring blank lines, so a missing or extra
    # trailing newline does not change the number of steps.
    with open(args.data_list, 'r') as listf:
        num_steps = sum(1 for line in listf if line.strip())

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    for step in tqdm(range(num_steps)):
        semantic_predictions_np, height_np, width_np = sess.run(
            [semantic_predictions, height, width], feed_dict=feed_dict)

        # Drop any padding beyond the reader-reported image size.
        semantic_predictions_np = semantic_predictions_np[:height_np, :width_np]

        # Always write PNG: swap only the real extension (replace('jpg', ...)
        # also hit 'jpg' elsewhere in the name and mangled '.jpeg').
        basename = os.path.basename(image_list[step])
        basename = os.path.splitext(basename)[0] + '.png'

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)