def preprocessing_fn_segmentation(image, mask, out_shape):
    """Forward one example to the project's segmentation preprocessing.

    ``image``, ``mask`` and ``out_shape`` are handed, positionally and in that
    order, to ``segmentation_preprocessing.segmentation_preprocessing`` along
    with ``is_training``; whatever that call returns is returned unchanged.

    NOTE(review): ``is_training`` is not a parameter here; it is looked up in an
    enclosing or module-level scope that is not visible in this chunk. Confirm
    it is defined wherever this wrapper is used.
    NOTE(review): other call sites in this file call
    ``segmentation_preprocessing`` with a different number of leading
    positional arguments (e.g. ``(image, None, None, out_shape=...)``). Verify
    that the third positional parameter really is ``out_shape`` in the module
    version this wrapper targets.
    """
    preprocess = segmentation_preprocessing.segmentation_preprocessing
    processed = preprocess(image, mask, out_shape, is_training)
    return processed
def evulate_dir_nii_weakly():
    """Evaluate a weakly supervised ``UNetBlocksMS`` checkpoint on 2D PNG cases.

    This is the "new version of the evaluation code" (translated from the
    original docstring). The checkpoint path is hard-coded in ``restore_path``.

    For every case sub-directory of ``FLAGS.dataset_dir`` (sorted by name):

    * reads the first ``images/*.png`` with ``cv2.imread``;
    * reads channel 1 of ``whole_mask.png`` as ground truth, resized to
      256x256 with nearest-neighbour interpolation;
    * runs the network with a fixed 256x256 output size;
    * builds three binary predictions:
      the raw classifier output (foreground score > 0.5);
      a k-means refinement (``image_expand`` -> ``cluster_postprocessing`` ->
      ``fill_region``);
      a refinement from the learned centers (``net_center_posprocessing`` ->
      ``fill_region``);
    * writes debug PNGs to ``./tmp`` and prints per-case Dice/IoU for all
      three predictions.

    Global Dice/IoU over all cases and the mean of the per-case scores are
    printed at the end.

    NOTE(review): ``gt`` is passed into ``cluster_postprocessing`` and
    ``net_center_posprocessing``. If those helpers use it to shape the
    prediction, the k-means/center metrics leak ground truth. Confirm.
    NOTE(review): images are read with ``cv2.imread`` (BGR channel order),
    whereas ``test`` reads them with ``util.img.imread(..., rgb=True)``.
    Confirm the network expects this channel order.

    :return: None; results are printed and debug images are written.
    """
    import os
    from metrics import dice, IoU
    from datasets.medicalImage import image_expand

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/DSW/DSW-upsampling-True-1-False-True-True/model.ckpt-10450'

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, None, out_shape=[256, 256], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNetBlocksMS.UNet(
            b_image, None, None, is_training=False, decoder=FLAGS.decoder,
            update_center_flag=FLAGS.update_center, batch_size=2,
            init_center_value=None, update_center_strategy=1,
            full_annotation_flag=False,
            output_shape_tensor=input_shape_placeholder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving-average shadow variables or the raw weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    # The mapping must reach the Saver. Without var_list the moving-average
    # branch above had no effect and the raw weights were always restored.
    saver = tf.train.Saver(var_list=variables_to_restore)

    # The debug images below are written into ./tmp; create it up front
    # because cv2.imwrite does not create missing directories.
    if not os.path.isdir('./tmp'):
        os.makedirs('./tmp')

    global_gt = []
    global_pred = []
    global_pred_kmeans = []
    global_pred_centers = []
    case_dices = []
    case_IoUs = []
    case_dices_kmeans = []
    case_IoUs_kmeans = []
    case_dices_centers = []
    case_IoUs_centers = []
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, restore_path)
        for case_name in case_names:
            image_data = cv2.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0])
            gt = cv2.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'whole_mask.png'))[0])[:, :, 1]
            gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_NEAREST)
            print(np.shape(image_data), np.shape(gt))

            recover_image, pixel_recover_feature, pixel_cls_scores, net_centers = sess.run(
                [net.pixel_recovery_logits, net.pixel_recovery_features,
                 net.pixel_cls_scores, net.centers],
                feed_dict={
                    image: image_data,
                    input_shape_placeholder: [256, 256],
                })
            # Channel 1 of the class scores is treated as the foreground score.
            pred = np.asarray(pixel_cls_scores[0, :, :, 1] > 0.5, np.uint8)
            pixel_recover_feature = pixel_recover_feature[0]
            print(np.shape(pred), case_name)

            cv2.imwrite('./tmp/%s_gt.png' % (case_name), np.asarray(gt * 200, np.uint8))
            cv2.imwrite('./tmp/%s_img.png' % (case_name),
                        np.asarray(cv2.resize(image_data, (256, 256)), np.uint8))
            # Clip before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            cv2.imwrite('./tmp/%s_recover_img.png' % (case_name),
                        np.asarray(np.clip(recover_image[0] * 255., 0, 255), np.uint8))
            cv2.imwrite('./tmp/%s_pred.png' % (case_name), np.asarray(pred * 200, np.uint8))

            # k-means result: dilate the raw prediction (kernel 5), refine it
            # with cluster_postprocessing (k=2), then apply fill_region.
            pred_kmeans = np.asarray(image_expand(pred, kernel_size=5), np.uint8)
            print(np.shape(pred_kmeans), np.shape(gt), np.shape(pixel_recover_feature))
            pred_kmeans = cluster_postprocessing(pred_kmeans, gt, pixel_recover_feature, k=2)
            pred_kmeans = fill_region(pred_kmeans)

            # Result derived from the network's learned centers.
            pred_centers = np.asarray(pred, np.uint8)
            pred_centers = net_center_posprocessing(
                pred_centers, centers=net_centers,
                pixel_wise_feature=pixel_recover_feature, gt=gt)
            pred_centers = fill_region(pred_centers)
            cv2.imwrite('./tmp/%s_pred_center.png' % (case_name),
                        np.asarray(pred_centers * 200, np.uint8))
            cv2.imwrite('./tmp/%s_pred_kmeans.png' % (case_name),
                        np.asarray(pred_kmeans * 200, np.uint8))

            global_gt.append(gt)
            global_pred.append(pred)
            global_pred_kmeans.append(pred_kmeans)
            global_pred_centers.append(pred_centers)

            case_dice = dice(gt, pred)
            case_IoU = IoU(gt, pred)
            case_dice_kmeans = dice(gt, pred_kmeans)
            case_IoU_kmeans = IoU(gt, pred_kmeans)
            case_dice_centers = dice(gt, pred_centers)
            case_IoU_centers = IoU(gt, pred_centers)
            print('case dice: ', case_dice)
            print('case IoU: ', case_IoU)
            print('case dice kmeans: ', case_dice_kmeans)
            print('case IoU kmeans: ', case_IoU_kmeans)
            print('case dice centers: ', case_dice_centers)
            print('case IoU centers: ', case_IoU_centers)
            case_dices.append(case_dice)
            case_IoUs.append(case_IoU)
            case_dices_kmeans.append(case_dice_kmeans)
            case_IoUs_kmeans.append(case_IoU_kmeans)
            case_dices_centers.append(case_dice_centers)
            case_IoUs_centers.append(case_IoU_centers)

    # Single-argument print calls behave the same under Python 2 and 3.
    print('global dice is %s' % (dice(global_gt, global_pred),))
    print('global IoU is %s' % (IoU(global_gt, global_pred),))
    print('mean of case dice is %s' % (np.mean(case_dices),))
    print('mean of case IoU is %s' % (np.mean(case_IoUs),))
    print('global dice (kmeans) is %s' % (dice(global_gt, global_pred_kmeans),))
    print('global IoU (kmeans) is %s' % (IoU(global_gt, global_pred_kmeans),))
    print('mean of case dice (kmeans) is %s' % (np.mean(case_dices_kmeans),))
    print('mean of case IoU (kmeans) is %s' % (np.mean(case_IoUs_kmeans),))
    print('global dice (centers) is %s' % (dice(global_gt, global_pred_centers),))
    print('global IoU (centers) is %s' % (IoU(global_gt, global_pred_centers),))
    print('mean of case dice (centers) is %s' % (np.mean(case_dices_centers),))
    print('mean of case IoU (centers) is %s' % (np.mean(case_IoUs_centers),))
def _overlap_percentages(gt, pred):
    """Return ``(IoU, Dice)``, both in percent, between the non-zero pixels of two masks.

    If both masks are empty they agree perfectly, so ``(100.0, 100.0)`` is
    returned rather than the ``nan`` that a plain 0/0 would produce.
    """
    gt_fg = np.asarray(gt) != 0
    pred_fg = np.asarray(pred) != 0
    intersection = float(np.sum(np.logical_and(gt_fg, pred_fg)))
    total = float(np.sum(gt_fg) + np.sum(pred_fg))
    if total == 0:
        return 100.0, 100.0
    # total - intersection is the size of the union, which is >= 1 here.
    iou = intersection / (total - intersection) * 100
    dice = 2.0 * intersection / total * 100
    return iou, dice


def test():
    """Evaluate ``UNet.UNet`` from ``FLAGS.checkpoint_path`` on 2D PNG cases.

    For every case sub-directory of ``FLAGS.dataset_dir`` (sorted by name):

    * segments the first ``images/*.png``;
    * resizes the binary prediction (foreground score > 0.5) back to the image
      resolution with nearest-neighbour interpolation;
    * scores it against channel 1 of ``weakly_label_whole_mask.png``;
    * writes ``FLAGS.pred_path/<case>.png`` (0/1) and
      ``FLAGS.pred_vis_path/<case>.png`` (0/200);
    * writes ``FLAGS.pred_assign_label_path/<case>.png``. This holds the label
      from ``get_assign_label`` restricted to the network-resolution
      prediction, shifted by one and scaled by 100.

    The mean IoU and Dice (percent) over all cases are printed at the end.

    NOTE(review): a second ``def test()`` appears later in this file. Assuming
    both are module-level, that later definition rebinds the name, so this one
    is unreachable via ``test``. Rename one of them if both are needed.
    """
    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, out_shape=[256, 256], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving-average shadow variables or the raw weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path

    IoUs = []
    dices = []
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0],
                rgb=True)
            pixel_cls_scores, pixel_recovery_feature_map = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_features[-1]],
                feed_dict={image: image_data})
            print('%d/%d: %s %s' % (idx + 1, len(case_names), case_name,
                                    np.shape(pixel_cls_scores)))

            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            height, width = np.shape(image_data)[:2]
            # cv2.resize takes dsize as (width, height).
            pred = cv2.resize(pos_score, (width, height),
                              interpolation=cv2.INTER_NEAREST)
            gt = util.img.imread(util.io.join_path(
                FLAGS.dataset_dir, case_name, 'weakly_label_whole_mask.png'))[:, :, 1]

            case_iou, case_dice = _overlap_percentages(gt, pred)
            IoUs.append(case_iou)
            dices.append(case_dice)

            cv2.imwrite(util.io.join_path(FLAGS.pred_path, case_name + '.png'), pred)
            cv2.imwrite(util.io.join_path(FLAGS.pred_vis_path, case_name + '.png'),
                        np.asarray(pred * 200))

            assign_label = get_assign_label(centers, pixel_recovery_feature_map)[0]
            # Zero the label outside the predicted foreground, then shift by
            # one: background becomes 1 and foreground becomes label + 1.
            assign_label = assign_label * pos_score
            assign_label += 1
            # NOTE(review): if get_assign_label can return labels >= 2,
            # (label + 1) * 100 exceeds 255 and wraps in the uint8 cast.
            cv2.imwrite(util.io.join_path(FLAGS.pred_assign_label_path, case_name + '.png'),
                        np.asarray(assign_label * 100, np.uint8))

    print('total mean of IoU is %s' % (np.mean(IoUs),))
    print('total mean of dice is %s' % (np.mean(dices),))
def generate_recovery_image_feature_map():
    """Export predictions, recovered images and masked recovery features per case.

    Restores ``FLAGS.checkpoint_path`` into a ``UNet.UNet`` graph. For each case
    sub-directory of ``FLAGS.dataset_dir`` (sorted by name) it writes:

    * ``FLAGS.pred_dir/<case>.png``: the binary prediction (foreground score
      > 0.5), resized to the input resolution with nearest-neighbour
      interpolation;
    * ``FLAGS.pred_vis_dir/<case>.png``: the same mask multiplied by 200;
    * ``FLAGS.recovery_img_dir/<case>.png``: the first batch element of
      ``net.pixel_recovery_value`` times 255, clipped to [0, 255];
    * ``FLAGS.recovery_feature_map_dir/<case>.npy``: float32 feature rows of
      the recovery feature map (resized to the input resolution), taken at
      the pixels where channel 0 of ``whole_mask.png`` equals 1.
    """
    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        # NOTE(review): out_shape is given as [width, height]; confirm this is
        # the order segmentation_preprocessing expects (moot for square sizes).
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None,
            out_shape=[FLAGS.eval_image_width, FLAGS.eval_image_height],
            is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        print('the decoder is ', FLAGS.decoder)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder,
                        update_center_flag=FLAGS.update_center, batch_size=1)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving-average shadow variables or the raw weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    # The mapping must reach the Saver. Without var_list the moving-average
    # branch above had no effect and the raw weights were always restored.
    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path
    # Resize the recovery features to each input's own spatial size.
    pixel_recovery_features = tf.image.resize_images(
        net.pixel_recovery_features, image_shape_placeholder)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)
        print('sum of centers is %s' % (np.sum(centers),))
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0],
                rgb=True)
            mask = cv2.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'whole_mask.png'))[0])[:, :, 0]
            (pixel_cls_scores, recovery_img, recovery_feature_map,
             b_image_v, global_step_v) = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_value,
                 pixel_recovery_features, b_image, global_step],
                feed_dict={
                    image: image_data,
                    image_shape_placeholder: np.shape(image_data)[:2],
                })

            print(global_step_v)
            stats = [
                np.shape(pixel_cls_scores),
                np.max(pixel_cls_scores[:, :, :, 1]),
                np.min(pixel_cls_scores[:, :, :, 1]),
                np.shape(recovery_img), np.max(recovery_img), np.min(recovery_img),
                np.max(b_image_v), np.min(b_image_v), np.shape(b_image_v),
            ]
            print('%d / %d: %s %s' % (idx + 1, len(case_names), case_name,
                                      ' '.join(str(s) for s in stats)))
            print('%s %s' % (np.shape(recovery_feature_map), np.shape(mask)))

            height, width = np.shape(image_data)[:2]
            pred_vis_path = util.io.join_path(FLAGS.pred_vis_dir, case_name + '.png')
            pred_path = util.io.join_path(FLAGS.pred_dir, case_name + '.png')
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            # cv2.resize takes dsize as (width, height).
            pred = cv2.resize(pos_score, (width, height),
                              interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(pred_vis_path, np.asarray(pred * 200, np.uint8))
            cv2.imwrite(pred_path, np.asarray(pred, np.uint8))

            recovery_img_path = util.io.join_path(FLAGS.recovery_img_dir, case_name + '.png')
            # Clip before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            cv2.imwrite(recovery_img_path,
                        np.asarray(np.clip(recovery_img[0] * 255, 0, 255), np.uint8))

            recovery_feature_map_path = util.io.join_path(
                FLAGS.recovery_feature_map_dir, case_name + '.npy')
            # NOTE(review): assumes whole_mask.png stores foreground as 1
            # (not 255); confirm against the dataset writer.
            rows, cols = np.where(mask == 1)
            features = np.asarray(recovery_feature_map[0][rows, cols, :], np.float32)
            print('the size of feature map is %s' % (np.shape(features),))
            np.save(recovery_feature_map_path, features)
def evulate_dir_nii_weakly_new():
    """Evaluate a ``UNetBlocksMS`` checkpoint slice-by-slice on NIfTI volumes.

    This is the "new version of the evaluation code" (translated from the
    original docstring). ``nii_dir``, ``save_dir`` and ``restore_path`` are
    hard-coded below.

    For each ``volume-<id>.nii`` in ``nii_dir`` (currently only
    ``volume-0.nii``; see the filter in the loop):

    * ``segmentation-<id>.nii`` is loaded and both files are converted to 2D
      slices with ``convertCase2PNGs``;
    * each slice is run once per entry of ``scales``, and the foreground
      scores (resized to the slice resolution) and the recovery features are
      averaged over scales;
    * slices whose ground-truth tumor mask is non-empty are thresholded at
      0.6, passed through ``open_operation`` and ``fill_region``, and also
      refined with ``grabcut`` inside the bounding box of the dilated
      prediction.

    Per-volume and global Dice/IoU are printed for the raw and GrabCut
    predictions.

    NOTE(review): slices with an empty ground-truth mask get an all-zero
    prediction without consulting the network, so false positives on those
    slices are never counted. Selecting slices by ground truth likely inflates
    the reported scores; confirm this is intended.
    NOTE(review): ``scales`` is not defined in this function and must come
    from module scope. Each entry is fed to a [2]-shaped int32 placeholder.
    NOTE(review): recovery features are averaged across scales without
    resizing, so np.mean needs every scale to yield the same spatial size.
    """
    from metrics import dice, IoU
    from datasets.medicalImage import convertCase2PNGs, image_expand
    # Hoisted out of the per-slice loop (previously imported on every tumor slice).
    from grabcut import grabcut

    nii_dir = '/home/give/Documents/dataset/ISBI2017/Training_Batch_1'
    save_dir = '/home/give/Documents/dataset/ISBI2017/weakly_label_segmentation_V4/Batch_1/DLSC_0/niis'
    restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly_V3-upsampling-1/model.ckpt-167824'

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, None, out_shape=input_shape_placeholder, is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNetBlocksMS.UNet(
            b_image, None, None, is_training=False, decoder=FLAGS.decoder,
            update_center_flag=FLAGS.update_center, batch_size=2,
            init_center_value=None, update_center_strategy=1,
            num_centers_k=FLAGS.num_centers_k, full_annotation_flag=False,
            output_shape_tensor=input_shape_placeholder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving-average shadow variables or the raw weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    # The mapping must reach the Saver. Without var_list the moving-average
    # branch above had no effect and the raw weights were always restored.
    saver = tf.train.Saver(var_list=variables_to_restore)

    nii_pathes = glob(os.path.join(nii_dir, 'volume-*.nii'))
    # Debug images go to ./tmp; cv2.imwrite does not create directories.
    if not os.path.isdir('./tmp'):
        os.makedirs('./tmp')

    global_gt = []
    global_pred = []
    global_pred_grabcut = []
    case_dices = []
    case_IoUs = []
    case_dices_grabcut = []
    case_IoUs_grabcut = []
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, restore_path)
        for nii_path in nii_pathes:
            nii_path_basename = os.path.basename(nii_path)
            # NOTE(review): debugging filter left in place; only volume-0 is evaluated.
            if nii_path_basename != 'volume-0.nii':
                continue
            # 'volume-<id>.nii' -> '<id>'. Parsed from the basename so that
            # dots or dashes in nii_dir cannot corrupt it.
            case_id = nii_path_basename.split('.')[0].split('-')[1]
            pred_dir = os.path.join(save_dir, case_id, 'pred')
            pred_vis_dir = os.path.join(save_dir, case_id, 'pred_vis')
            recovery_img_dir = os.path.join(save_dir, case_id, 'recovery_img')
            for out_dir in (pred_dir, pred_vis_dir, recovery_img_dir):
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
            seg_path = os.path.join(nii_dir, 'segmentation-' + case_id + '.nii')

            case_preds = []
            case_gts = []
            case_preds_grabcut = []
            print(nii_path, seg_path)
            imgs, tumor_masks, liver_masks, tumor_weak_masks = convertCase2PNGs(
                nii_path, seg_path, save_dir=None)
            print(len(imgs), len(tumor_masks), len(liver_masks), len(tumor_weak_masks))

            for slice_idx, (image_data, _liver_mask, whole_mask) in enumerate(
                    zip(imgs, liver_masks, tumor_masks)):
                # cv2.resize takes dsize as (width, height).
                dsize = tuple(np.shape(image_data)[:2][::-1])
                scores_per_scale = []
                features_per_scale = []
                for single_scale in scales:
                    pixel_recover_feature, pixel_cls_scores = sess.run(
                        [net.pixel_recovery_features, net.pixel_cls_scores],
                        feed_dict={
                            image: image_data,
                            input_shape_placeholder: single_scale,
                        })
                    scores_per_scale.append(cv2.resize(pixel_cls_scores[0, :, :, 1], dsize))
                    features_per_scale.append(pixel_recover_feature[0])
                    del pixel_recover_feature
                pixel_cls_scores = np.mean(scores_per_scale, axis=0)
                pixel_recover_feature = np.mean(features_per_scale, axis=0)

                if np.sum(whole_mask) != 0:
                    pred = np.asarray(pixel_cls_scores > 0.6, np.uint8)
                    # Opening (erode first, then dilate), then fill_region.
                    pred = open_operation(pred, kernel_size=3)
                    pred = fill_region(pred)

                    pred_grabcut = np.asarray(image_expand(pred, kernel_size=5), np.uint8)
                    xs, ys = np.where(pred_grabcut == 1)
                    print(np.shape(pred_grabcut))
                    if len(xs) == 0:
                        pred_grabcut = np.zeros_like(whole_mask)
                        print(np.min(pred_grabcut), np.max(pred_grabcut), np.sum(pred_grabcut))
                    else:
                        # NOTE(review): the box is passed as
                        # [min_row, min_col, max_row, max_col]; confirm this
                        # matches the rect convention of the local grabcut helper.
                        pred_grabcut = grabcut(
                            np.asarray(image_data * 255., np.uint8),
                            [np.min(xs), np.min(ys), np.max(xs), np.max(ys)])
                        pred_grabcut = np.asarray(pred_grabcut == 255, np.uint8)
                        print(np.unique(pred_grabcut))
                    cv2.imwrite('./tmp/%d_gt.png' % slice_idx,
                                np.asarray(whole_mask * 200, np.uint8))
                    cv2.imwrite('./tmp/%d_pred.png' % slice_idx,
                                np.asarray(pred * 200, np.uint8))
                    cv2.imwrite('./tmp/%d_pred_grabcut.png' % slice_idx,
                                np.asarray(pred_grabcut * 200, np.uint8))
                else:
                    pred = np.zeros_like(whole_mask)
                    pred_grabcut = np.zeros_like(whole_mask)

                global_gt.append(whole_mask)
                case_gts.append(whole_mask)
                case_preds.append(pred)
                case_preds_grabcut.append(pred_grabcut)
                global_pred.append(pred)
                global_pred_grabcut.append(pred_grabcut)
                print('%d / %d: %s %s %s %s %s' % (
                    slice_idx + 1, len(imgs), nii_path_basename,
                    np.shape(pixel_cls_scores), np.max(pixel_cls_scores),
                    np.min(pixel_cls_scores), np.shape(pixel_recover_feature)))
                # Drop the per-slice feature maps before the next slice.
                del pixel_recover_feature, features_per_scale
                gc.collect()

            case_dice = dice(case_gts, case_preds)
            case_IoU = IoU(case_gts, case_preds)
            case_dice_grabcut = dice(case_gts, case_preds_grabcut)
            case_IoU_grabcut = IoU(case_gts, case_preds_grabcut)
            print('case dice: ', case_dice)
            print('case IoU: ', case_IoU)
            print('case dice grabcut: ', case_dice_grabcut)
            print('case IoU grabcut: ', case_IoU_grabcut)
            case_dices.append(case_dice)
            case_IoUs.append(case_IoU)
            case_dices_grabcut.append(case_dice_grabcut)
            case_IoUs_grabcut.append(case_IoU_grabcut)

    # Single-argument print calls behave the same under Python 2 and 3.
    print('global dice is %s' % (dice(global_gt, global_pred),))
    print('global IoU is %s' % (IoU(global_gt, global_pred),))
    print('mean of case dice is %s' % (np.mean(case_dices),))
    print('mean of case IoU is %s' % (np.mean(case_IoUs),))
    print('mean of case dice (grabcut) is %s' % (np.mean(case_dices_grabcut),))
    print('mean of case IoU (grabcut) is %s' % (np.mean(case_IoUs_grabcut),))
    print('global dice (grabcut) is %s' % (dice(global_gt, global_pred_grabcut),))
    print('global IoU (grabcut) is %s' % (IoU(global_gt, global_pred_grabcut),))
def test():
    """Run ``UNet.UNet`` from ``FLAGS.checkpoint_path`` on every case and save masks.

    For each case sub-directory of ``FLAGS.dataset_dir`` (sorted by name):

    * the first ``images/*.png`` is segmented;
    * the binary mask (foreground score > 0.5) is resized back to the image
      resolution with nearest-neighbour interpolation;
    * the mask is saved to ``FLAGS.pred_path/<case>.png`` (values 0/1) and to
      ``FLAGS.pred_vis_path/<case>.png`` (values 0/200).

    NOTE(review): this is the second ``def test()`` in the file. Assuming both
    are module-level, this later definition is the one bound to ``test``.
    """
    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, out_shape=[256, 256], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNet.UNet(b_image, None, is_training=False)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving-average shadow variables or the raw weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0],
                rgb=True)
            pixel_cls_scores, = sess.run([net.pixel_cls_scores],
                                         feed_dict={image: image_data})
            print('%d/%d: %s %s' % (idx + 1, len(case_names), case_name,
                                    np.shape(pixel_cls_scores)))
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            # cv2.resize takes dsize as (width, height). Passing
            # np.shape(...)[:2] directly gave (height, width) and transposed the
            # output size for non-square images.
            height, width = np.shape(image_data)[:2]
            pos_score = cv2.resize(pos_score, (width, height),
                                   interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(util.io.join_path(FLAGS.pred_path, case_name + '.png'), pos_score)
            cv2.imwrite(util.io.join_path(FLAGS.pred_vis_path, case_name + '.png'),
                        np.asarray(pos_score * 200))