Ejemplo n.º 1
0
 def preprocessing_fn_segmentation(image, mask, out_shape):
     """Preprocess an image/mask pair to ``out_shape`` for segmentation.

     Thin wrapper over ``segmentation_preprocessing.segmentation_preprocessing``.
     ``is_training`` is not a parameter: it is looked up in the enclosing
     scope at call time (NOTE(review): confirm it is defined there).
     """
     preprocess = segmentation_preprocessing.segmentation_preprocessing
     return preprocess(image, mask, out_shape, is_training)
Ejemplo n.º 2
0
def evulate_dir_nii_weakly(restore_path=None, tmp_dir='./tmp'):
    '''
    Evaluate a weakly-supervised UNetBlocksMS checkpoint on FLAGS.dataset_dir
    (new version of the evaluation code).

    Every entry of FLAGS.dataset_dir is treated as one case. The first PNG
    under ``<case>/images`` is segmented at ``eval_size`` x ``eval_size``. It is
    scored against channel 1 of ``<case>/whole_mask.png``, resized with
    nearest-neighbour interpolation.

    Three predictions are scored per case:

    * ``raw``     - network foreground score > 0.5;
    * ``kmeans``  - raw prediction passed through ``image_expand``
                    (kernel 5), ``cluster_postprocessing`` (k=2) and
                    ``fill_region``;
    * ``centers`` - raw prediction passed through
                    ``net_center_posprocessing`` with the network centers and
                    ``fill_region``.

    Debug PNGs are written to ``tmp_dir``. Per-case and global Dice/IoU are
    printed for each of the three predictions.

    :param restore_path: checkpoint to restore. Defaults to the DSW checkpoint
        this function was originally hard-coded to.
    :param tmp_dir: directory for debug images; created if it does not exist.
    :return: None
    '''
    import os
    from metrics import dice, IoU
    from datasets.medicalImage import image_expand

    if restore_path is None:
        restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/DSW/DSW-upsampling-True-1-False-True-True/model.ckpt-10450'
    # Network input/output resolution; ground truth is resized to match.
    eval_size = 256

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, None, out_shape=[eval_size, eval_size], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)

        net = UNetBlocksMS.UNet(b_image, None, None, is_training=False, decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=2, init_center_value=None, update_center_strategy=1,
                                full_annotation_flag=False,
                                output_shape_tensor=input_shape_placeholder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    # Pass the mapping explicitly. A bare Saver() would ignore
    # FLAGS.using_moving_average and always restore the raw weights.
    saver = tf.train.Saver(var_list=variables_to_restore)

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir)

    def tmp_png(case_name, suffix):
        # Debug image path, e.g. <tmp_dir>/<case>_pred.png
        return os.path.join(tmp_dir, '%s_%s.png' % (case_name, suffix))

    method_names = ('raw', 'kmeans', 'centers')
    global_gt = []
    global_preds = dict((name, []) for name in method_names)
    case_dices = dict((name, []) for name in method_names)
    case_IoUs = dict((name, []) for name in method_names)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, restore_path)

        for case_name in case_names:
            image_path = glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0]
            gt_path = util.io.join_path(FLAGS.dataset_dir, case_name, 'whole_mask.png')
            # NOTE(review): cv2.imread returns BGR, while the other evaluation
            # helpers read RGB (util.img.imread(..., rgb=True)); confirm the
            # channel order matches training.
            image_data = cv2.imread(image_path)
            gt_data = cv2.imread(gt_path)
            if image_data is None or gt_data is None:
                raise IOError('could not read image %s or mask %s' % (image_path, gt_path))
            gt = cv2.resize(gt_data[:, :, 1], (eval_size, eval_size), interpolation=cv2.INTER_NEAREST)
            print('%s %s' % (np.shape(image_data), np.shape(gt)))

            recover_image, pixel_recover_feature, pixel_cls_scores, net_centers = sess.run(
                [net.pixel_recovery_logits, net.pixel_recovery_features,
                 net.pixel_cls_scores, net.centers],
                feed_dict={
                    image: image_data,
                    input_shape_placeholder: [eval_size, eval_size],
                })

            pred = np.asarray(pixel_cls_scores[0, :, :, 1] > 0.5, np.uint8)
            pixel_recover_feature = pixel_recover_feature[0]
            print('%s %s' % (np.shape(pred), case_name))
            cv2.imwrite(tmp_png(case_name, 'gt'), np.asarray(gt * 200, np.uint8))
            cv2.imwrite(tmp_png(case_name, 'img'),
                        np.asarray(cv2.resize(image_data, (eval_size, eval_size)), np.uint8))
            cv2.imwrite(tmp_png(case_name, 'recover_img'), np.asarray(recover_image[0] * 255., np.uint8))
            cv2.imwrite(tmp_png(case_name, 'pred'), np.asarray(pred * 200, np.uint8))

            # NOTE(review): the ground truth ``gt`` is passed into both
            # post-processing helpers and is then used to score their output.
            # Confirm they do not use it to shape the prediction; otherwise
            # the kmeans/centers scores leak labels.
            pred_kmeans = np.asarray(image_expand(pred, kernel_size=5), np.uint8)
            print('%s %s %s' % (np.shape(pred_kmeans), np.shape(gt), np.shape(pixel_recover_feature)))
            pred_kmeans = cluster_postprocessing(pred_kmeans, gt, pixel_recover_feature, k=2)
            pred_kmeans = fill_region(pred_kmeans)

            pred_centers = net_center_posprocessing(np.asarray(pred, np.uint8), centers=net_centers,
                                                    pixel_wise_feature=pixel_recover_feature, gt=gt)
            pred_centers = fill_region(pred_centers)
            cv2.imwrite(tmp_png(case_name, 'pred_center'), np.asarray(pred_centers * 200, np.uint8))
            cv2.imwrite(tmp_png(case_name, 'pred_kmeans'), np.asarray(pred_kmeans * 200, np.uint8))

            global_gt.append(gt)
            case_preds = {'raw': pred, 'kmeans': pred_kmeans, 'centers': pred_centers}
            for name in method_names:
                case_pred = case_preds[name]
                global_preds[name].append(case_pred)
                case_dice = dice(gt, case_pred)
                case_IoU = IoU(gt, case_pred)
                case_dices[name].append(case_dice)
                case_IoUs[name].append(case_IoU)
                label = '' if name == 'raw' else ' ' + name
                print('case dice%s: %s' % (label, case_dice))
                print('case IoU%s: %s' % (label, case_IoU))

    for name in method_names:
        label = '' if name == 'raw' else ' (%s)' % name
        print('global dice%s is %s' % (label, dice(global_gt, global_preds[name])))
        print('global IoU%s is %s' % (label, IoU(global_gt, global_preds[name])))
        print('mean of case dice%s is %s' % (label, np.mean(case_dices[name])))
        print('mean of case IoU%s is %s' % (label, np.mean(case_IoUs[name])))
Ejemplo n.º 3
0
def _overlap_scores_percent(gt, pred):
    """Return ``(IoU, Dice)`` in percent between the non-zero pixels of two masks.

    When both masks are empty, the plain formulas evaluate to 0/0. That
    yields NaN and poisons the final ``np.mean``, so the case is treated as
    perfect agreement: (100.0, 100.0).
    """
    gt_fg = np.asarray(gt) != 0
    pred_fg = np.asarray(pred) != 0
    intersection = float(np.sum(np.logical_and(gt_fg, pred_fg)))
    # |gt| + |pred|; the union is this minus the intersection.
    total = float(np.sum(gt_fg) + np.sum(pred_fg))
    if total == 0:
        return 100.0, 100.0
    iou = intersection / (total - intersection) * 100
    dice = 2.0 * intersection / total * 100
    return iou, dice


def test():
    '''
    Segment every case under FLAGS.dataset_dir with the UNet checkpoint at
    FLAGS.checkpoint_path and print the mean IoU / Dice (in percent).

    Per case, the first PNG in ``<case>/images`` (read as RGB) is segmented
    and the foreground mask (score > 0.5) is resized back to the image size.
    It is scored against channel 1 of ``<case>/weakly_label_whole_mask.png``.

    Writes three files per case:

    * FLAGS.pred_path/<case>.png - binary prediction;
    * FLAGS.pred_vis_path/<case>.png - prediction * 200;
    * FLAGS.pred_assign_label_path/<case>.png - ``get_assign_label`` output
      masked by the prediction, shifted by 1, * 100.

    Output directories are created if missing.
    '''
    import os

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, out_shape=[256, 256], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    for out_dir in (FLAGS.pred_path, FLAGS.pred_vis_path, FLAGS.pred_assign_label_path):
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    IoUs = []
    dices = []
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)

        num_cases = len(case_names)
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0], rgb=True)

            pixel_cls_scores, pixel_recovery_feature_map = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_features[-1]],
                feed_dict={image: image_data})
            print('%d/%d: %s %s' % (idx + 1, num_cases, case_name, np.shape(pixel_cls_scores)))
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            # cv2.resize takes dsize as (width, height), hence the reversed shape.
            pred = cv2.resize(pos_score, tuple(np.shape(image_data)[:2][::-1]), interpolation=cv2.INTER_NEAREST)
            gt = util.img.imread(
                util.io.join_path(FLAGS.dataset_dir, case_name, 'weakly_label_whole_mask.png'))[:, :, 1]
            case_iou, case_dice = _overlap_scores_percent(gt, pred)
            IoUs.append(case_iou)
            dices.append(case_dice)
            cv2.imwrite(util.io.join_path(FLAGS.pred_path, case_name + '.png'), pred)
            cv2.imwrite(util.io.join_path(FLAGS.pred_vis_path, case_name + '.png'),
                        np.asarray(pred * 200))
            # Center assignment restricted to predicted foreground; +1 keeps
            # background (0 after masking) distinguishable in the saved image.
            assign_label = get_assign_label(centers, pixel_recovery_feature_map)[0]
            assign_label = assign_label * pos_score
            assign_label += 1
            cv2.imwrite(util.io.join_path(FLAGS.pred_assign_label_path, case_name + '.png'),
                        np.asarray(assign_label * 100, np.uint8))
    print('total mean of IoU is %s' % np.mean(IoUs))
    print('total mean of dice is %s' % np.mean(dices))
Ejemplo n.º 4
0
def generate_recovery_image_feature_map():
    '''
    Dump predictions, recovered images and per-pixel recovery features for
    every case under FLAGS.dataset_dir.

    For each case, the first PNG in ``<case>/images`` (read as RGB) is run
    through the UNet checkpoint at FLAGS.checkpoint_path. Written outputs:

    * FLAGS.pred_vis_dir/<case>.png - prediction (score > 0.5) * 200;
    * FLAGS.pred_dir/<case>.png - binary prediction (0/1);
    * FLAGS.recovery_img_dir/<case>.png - network recovery image * 255;
    * FLAGS.recovery_feature_map_dir/<case>.npy - float32 recovery features
      (resized to the input image size) at pixels where channel 0 of
      ``<case>/whole_mask.png`` equals 1.

    Output directories are created if missing.
    '''
    import os

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        # NOTE(review): out_shape is given as [width, height]; confirm
        # segmentation_preprocessing expects that order (it only matters for
        # non-square eval sizes).
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None,
            out_shape=[FLAGS.eval_image_width, FLAGS.eval_image_height],
            is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        print('the decoder is %s' % FLAGS.decoder)
        net = UNet.UNet(b_image, None, is_training=False, decoder=FLAGS.decoder,
                        update_center_flag=FLAGS.update_center, batch_size=1)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    # Pass the mapping explicitly. A bare Saver() would ignore
    # FLAGS.using_moving_average and always restore the raw weights.
    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path

    # np.save raises and cv2.imwrite silently fails on a missing directory.
    for out_dir in (FLAGS.pred_vis_dir, FLAGS.pred_dir,
                    FLAGS.recovery_img_dir, FLAGS.recovery_feature_map_dir):
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    # Resize the recovery features to the original image size so they can be
    # indexed with full-resolution mask coordinates below.
    pixel_recovery_features = tf.image.resize_images(net.pixel_recovery_features, image_shape_placeholder)
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)
        centers = sess.run(net.centers)
        print('sum of centers is %s' % np.sum(centers))

        num_cases = len(case_names)
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0], rgb=True)
            mask_path = util.io.join_path(FLAGS.dataset_dir, case_name, 'whole_mask.png')
            mask_data = cv2.imread(mask_path)
            if mask_data is None:
                raise IOError('could not read mask %s' % mask_path)
            mask = mask_data[:, :, 0]
            image_hw = np.shape(image_data)[:2]

            pixel_cls_scores, recovery_img, recovery_feature_map, b_image_v, global_step_v = sess.run(
                [net.pixel_cls_scores, net.pixel_recovery_value, pixel_recovery_features, b_image, global_step],
                feed_dict={
                    image: image_data,
                    image_shape_placeholder: image_hw,
                })
            fg_scores = pixel_cls_scores[:, :, :, 1]
            print(global_step_v)
            stats = [np.shape(pixel_cls_scores), np.max(fg_scores), np.min(fg_scores),
                     np.shape(recovery_img), np.max(recovery_img), np.min(recovery_img),
                     np.max(b_image_v), np.min(b_image_v), np.shape(b_image_v)]
            print('%d / %d: %s %s' % (idx + 1, num_cases, case_name, ' '.join(str(v) for v in stats)))
            print('%s %s' % (np.shape(recovery_feature_map), np.shape(mask)))

            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            # cv2.resize takes dsize as (width, height), hence the reversed shape.
            pred = cv2.resize(pos_score, tuple(image_hw[::-1]), interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(util.io.join_path(FLAGS.pred_vis_dir, case_name + '.png'),
                        np.asarray(pred * 200, np.uint8))
            cv2.imwrite(util.io.join_path(FLAGS.pred_dir, case_name + '.png'),
                        np.asarray(pred, np.uint8))
            cv2.imwrite(util.io.join_path(FLAGS.recovery_img_dir, case_name + '.png'),
                        np.asarray(recovery_img[0] * 255, np.uint8))

            # Keep only the feature vectors of pixels labelled 1 in the mask.
            rows, cols = np.where(mask == 1)
            features = np.asarray(recovery_feature_map[0][rows, cols, :], np.float32)
            print('the size of feature map is %s' % (np.shape(features),))
            np.save(util.io.join_path(FLAGS.recovery_feature_map_dir, case_name + '.npy'), features)
Ejemplo n.º 5
0
def evulate_dir_nii_weakly_new(nii_dir=None, save_dir=None, restore_path=None,
                               volume_names=('volume-0.nii',), tmp_dir='./tmp'):
    '''
    Evaluate a multi-scale UNetBlocksMS checkpoint slice-by-slice on NIfTI
    volumes (new version of the evaluation code).

    For every ``volume-<id>.nii`` in ``nii_dir`` (restricted to
    ``volume_names``), ``segmentation-<id>.nii`` is loaded together with it
    through ``convertCase2PNGs``. Each slice is run once per entry of the
    module-level ``scales`` and the foreground scores are averaged.

    On slices whose tumour mask is non-empty:

    * ``pred`` is the averaged score > 0.6 passed through ``open_operation``
      (kernel 3) and ``fill_region``;
    * ``pred_grabcut`` is the ``grabcut`` result inside the bounding box of
      ``image_expand(pred)``.

    On slices without tumour both predictions are all-zero. Per-volume and
    global Dice/IoU are printed for both predictions.

    :param nii_dir: directory with volume-*.nii / segmentation-*.nii files.
    :param save_dir: root under which per-volume pred / pred_vis /
        recovery_img directories are created.
    :param restore_path: checkpoint to restore.
        ``nii_dir``, ``save_dir`` and ``restore_path`` default to the paths
        this function was originally hard-coded to.
    :param volume_names: basenames of the volumes to evaluate; ``None``
        evaluates every volume. The default keeps the original behaviour of
        evaluating only ``volume-0.nii``.
    :param tmp_dir: directory for debug PNGs of tumour slices (created if
        missing).
    :return: None
    '''
    import gc
    import os
    from metrics import dice, IoU
    from datasets.medicalImage import convertCase2PNGs, image_expand
    from grabcut import grabcut

    if nii_dir is None:
        nii_dir = '/home/give/Documents/dataset/ISBI2017/Training_Batch_1'
    if save_dir is None:
        save_dir = '/home/give/Documents/dataset/ISBI2017/weakly_label_segmentation_V4/Batch_1/DLSC_0/niis'
    if restore_path is None:
        restore_path = '/home/give/PycharmProjects/weakly_label_segmentation/logs/ISBI2017_V2/1s_agumentation_weakly_V3-upsampling-1/model.ckpt-167824'

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])
        image_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        input_shape_placeholder = tf.placeholder(tf.int32, shape=[2])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image,
            None,
            None,
            out_shape=input_shape_placeholder,
            is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)

        net = UNetBlocksMS.UNet(b_image,
                                None,
                                None,
                                is_training=False,
                                decoder=FLAGS.decoder,
                                update_center_flag=FLAGS.update_center,
                                batch_size=2,
                                init_center_value=None,
                                update_center_strategy=1,
                                num_centers_k=FLAGS.num_centers_k,
                                full_annotation_flag=False,
                                output_shape_tensor=input_shape_placeholder)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    # Pass the mapping explicitly. A bare Saver() would ignore
    # FLAGS.using_moving_average and always restore the raw weights.
    saver = tf.train.Saver(var_list=variables_to_restore)

    # Sorted for a deterministic evaluation order.
    nii_pathes = sorted(glob(os.path.join(nii_dir, 'volume-*.nii')))
    if volume_names is not None:
        wanted = set(volume_names)
        nii_pathes = [path for path in nii_pathes if os.path.basename(path) in wanted]

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, restore_path)

        global_gt = []
        global_pred = []
        global_pred_grabcut = []

        case_dices = []
        case_IoUs = []
        case_dices_grabcut = []
        case_IoUs_grabcut = []

        for nii_path in nii_pathes:
            nii_path_basename = os.path.basename(nii_path)
            # Derive the case id ('volume-<id>.nii' -> '<id>') from the
            # basename. Splitting the full path breaks as soon as a parent
            # directory contains '.' or '-'.
            case_id = nii_path_basename.split('.')[0].split('-')[1]
            for sub_dir in ('pred', 'pred_vis', 'recovery_img'):
                out_dir = os.path.join(save_dir, case_id, sub_dir)
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
            seg_path = os.path.join(nii_dir, 'segmentation-' + case_id + '.nii')

            case_preds = []
            case_gts = []
            case_preds_grabcut = []

            print('%s %s' % (nii_path, seg_path))
            imgs, tumor_masks, liver_masks, tumor_weak_masks = convertCase2PNGs(
                nii_path, seg_path, save_dir=None)
            print('%d %d %d %d' % (len(imgs), len(tumor_masks), len(liver_masks),
                                   len(tumor_weak_masks)))

            for slice_idx, (image_data, _liver_mask, whole_mask) in enumerate(
                    zip(imgs, liver_masks, tumor_masks)):
                image_hw = np.shape(image_data)[:2]
                pixel_cls_scores_ms = []
                pixel_recover_feature_ms = []
                # NOTE(review): ``scales`` is a module-level name not visible
                # here; presumably a list of [h, w] network input sizes.
                for single_scale in scales:
                    pixel_recover_feature, pixel_cls_scores = sess.run(
                        [net.pixel_recovery_features, net.pixel_cls_scores],
                        feed_dict={
                            image: image_data,
                            image_shape_placeholder: image_hw,
                            input_shape_placeholder: single_scale
                        })
                    # Bring every scale's foreground score back to the slice
                    # size (cv2 dsize is (width, height)) before averaging.
                    pixel_cls_scores_ms.append(
                        cv2.resize(pixel_cls_scores[0, :, :, 1], tuple(image_hw[::-1])))
                    pixel_recover_feature_ms.append(pixel_recover_feature[0])
                    del pixel_recover_feature
                pixel_cls_scores = np.mean(pixel_cls_scores_ms, axis=0)
                # NOTE(review): the averaged features are only used for the
                # shape printed below; averaging assumes every scale returns
                # features of the same shape.
                pixel_recover_feature = np.mean(pixel_recover_feature_ms, axis=0)

                # NOTE(review): slices with an empty ground-truth mask always
                # get an all-zero prediction, so false positives on those
                # slices are never counted.
                if np.sum(whole_mask) != 0:
                    pred = np.asarray(pixel_cls_scores > 0.6, np.uint8)
                    # Opening: erosion followed by dilation.
                    pred = open_operation(pred, kernel_size=3)
                    pred = fill_region(pred)
                    pred_grabcut = np.asarray(image_expand(pred, kernel_size=5), np.uint8)
                    rows, cols = np.where(pred_grabcut == 1)
                    print(np.shape(pred_grabcut))
                    if len(rows) == 0:
                        pred_grabcut = np.zeros_like(whole_mask)
                        print('%s %s %s' % (np.min(pred_grabcut), np.max(pred_grabcut),
                                            np.sum(pred_grabcut)))
                    else:
                        # NOTE(review): np.where yields row indices first, so
                        # this box is [min_row, min_col, max_row, max_col];
                        # confirm this matches grabcut's rectangle convention.
                        bbox = [np.min(rows), np.min(cols), np.max(rows), np.max(cols)]
                        pred_grabcut = grabcut(np.asarray(image_data * 255., np.uint8), bbox)
                        pred_grabcut = np.asarray(pred_grabcut == 255, np.uint8)
                        print(np.unique(pred_grabcut))
                        cv2.imwrite(os.path.join(tmp_dir, '%d_gt.png' % slice_idx),
                                    np.asarray(whole_mask * 200, np.uint8))
                        cv2.imwrite(os.path.join(tmp_dir, '%d_pred.png' % slice_idx),
                                    np.asarray(pred * 200, np.uint8))
                        cv2.imwrite(os.path.join(tmp_dir, '%d_pred_grabcut.png' % slice_idx),
                                    np.asarray(pred_grabcut * 200, np.uint8))
                else:
                    pred = np.zeros_like(whole_mask)
                    pred_grabcut = np.zeros_like(whole_mask)

                global_gt.append(whole_mask)
                case_gts.append(whole_mask)
                case_preds.append(pred)
                case_preds_grabcut.append(pred_grabcut)
                global_pred.append(pred)
                global_pred_grabcut.append(pred_grabcut)

                print('%d / %d: %s %s %s %s %s' % (
                    slice_idx + 1, len(imgs), nii_path_basename,
                    np.shape(pixel_cls_scores), np.max(pixel_cls_scores),
                    np.min(pixel_cls_scores), np.shape(pixel_recover_feature)))
                # Release the per-slice feature maps eagerly to keep memory bounded.
                del pixel_recover_feature, pixel_recover_feature_ms
                gc.collect()

            case_dice = dice(case_gts, case_preds)
            case_IoU = IoU(case_gts, case_preds)
            case_dice_grabcut = dice(case_gts, case_preds_grabcut)
            case_IoU_grabcut = IoU(case_gts, case_preds_grabcut)
            print('case dice: %s' % (case_dice,))
            print('case IoU: %s' % (case_IoU,))
            print('case dice grabcut: %s' % (case_dice_grabcut,))
            print('case IoU grabcut: %s' % (case_IoU_grabcut,))
            case_dices.append(case_dice)
            case_IoUs.append(case_IoU)
            case_dices_grabcut.append(case_dice_grabcut)
            case_IoUs_grabcut.append(case_IoU_grabcut)

        print('global dice is %s' % (dice(global_gt, global_pred),))
        print('global IoU is %s' % (IoU(global_gt, global_pred),))

        print('mean of case dice is %s' % np.mean(case_dices))
        print('mean of case IoU is %s' % np.mean(case_IoUs))

        print('mean of case dice (grabcut) is %s' % np.mean(case_dices_grabcut))
        print('mean of case IoU (grabcut) is %s' % np.mean(case_IoUs_grabcut))

        print('global dice (grabcut) is %s' % (dice(global_gt, global_pred_grabcut),))
        print('global IoU (grabcut) is %s' % (IoU(global_gt, global_pred_grabcut),))
Ejemplo n.º 6
0
def test():
    '''
    Run the UNet checkpoint at FLAGS.checkpoint_path over every case under
    FLAGS.dataset_dir and write binary predictions.

    Per case, the first PNG in ``<case>/images`` (read as RGB) is segmented.
    The foreground mask (score > 0.5) is resized back to the input image size
    and written to FLAGS.pred_path/<case>.png, with a * 200 visualisation in
    FLAGS.pred_vis_path/<case>.png. Output directories are created if
    missing.
    '''
    import os

    with tf.name_scope('test'):
        image = tf.placeholder(dtype=tf.int32, shape=[None, None, 3])
        processed_image = segmentation_preprocessing.segmentation_preprocessing(
            image, None, out_shape=[256, 256], is_training=False)
        b_image = tf.expand_dims(processed_image, axis=0)
        net = UNet.UNet(b_image, None, is_training=False)
        global_step = slim.get_or_create_global_step()

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    if FLAGS.gpu_memory_fraction < 0:
        sess_config.gpu_options.allow_growth = True
    elif FLAGS.gpu_memory_fraction > 0:
        sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction

    # Variables to restore: moving avg. or normal weights.
    if FLAGS.using_moving_average:
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[global_step.op.name] = global_step
    else:
        variables_to_restore = slim.get_variables_to_restore()

    saver = tf.train.Saver(var_list=variables_to_restore)

    case_names = util.io.ls(FLAGS.dataset_dir)
    case_names.sort()
    checkpoint = FLAGS.checkpoint_path

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    for out_dir in (FLAGS.pred_path, FLAGS.pred_vis_path):
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    with tf.Session(config=sess_config) as sess:
        saver.restore(sess, checkpoint)

        num_cases = len(case_names)
        for idx, case_name in enumerate(case_names):
            image_data = util.img.imread(
                glob(util.io.join_path(FLAGS.dataset_dir, case_name, 'images', '*.png'))[0],
                rgb=True)
            pixel_cls_scores = sess.run(net.pixel_cls_scores,
                                        feed_dict={image: image_data})
            print('%d/%d: %s %s' % (idx + 1, num_cases, case_name,
                                    np.shape(pixel_cls_scores)))
            pos_score = np.asarray(pixel_cls_scores > 0.5, np.uint8)[0, :, :, 1]
            height, width = np.shape(image_data)[:2]
            # cv2.resize expects dsize as (width, height). Passing the numpy
            # (height, width) shape would transpose non-square predictions.
            pos_score = cv2.resize(pos_score, (width, height),
                                   interpolation=cv2.INTER_NEAREST)

            cv2.imwrite(util.io.join_path(FLAGS.pred_path, case_name + '.png'),
                        pos_score)
            cv2.imwrite(
                util.io.join_path(FLAGS.pred_vis_path, case_name + '.png'),
                np.asarray(pos_score * 200))