def calculate_metrics(mask1, mask2):
    true_positives = metric.obj_tpr(mask1, mask2)
    false_positives = metric.obj_fpr(mask1, mask2)
    dc = metric.dc(mask1, mask2)
    hd = metric.hd(mask1, mask2)
    precision = metric.precision(mask1, mask2)
    recall = metric.recall(mask1, mask2)
    ravd = metric.ravd(mask1, mask2)
    assd = metric.assd(mask1, mask2)
    asd = metric.asd(mask1, mask2)
    return true_positives, false_positives, dc, hd, precision, recall, ravd, assd, asd
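A minimal usage sketch for the helper above, assuming `from medpy import metric` is available at module level as the function requires (the synthetic masks are illustrative):

import numpy as np
from medpy import metric

# Two small synthetic binary masks; in practice these are segmentation volumes.
mask_pred = np.zeros((16, 16, 16), dtype=bool)
mask_pred[4:10, 4:10, 4:10] = True
mask_gt = np.zeros((16, 16, 16), dtype=bool)
mask_gt[5:11, 5:11, 5:11] = True

tp, fp, dc, hd, prec, rec, ravd, assd, asd = calculate_metrics(mask_pred, mask_gt)
print('Dice={:.3f}, HD={:.2f}, ASSD={:.2f}'.format(dc, hd, assd))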
Code Example #3
File: metrics.py  Project: MarEe0/FeTA-Spatial
def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.assd(test, reference, voxel_spacing, connectivity)
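Once the empty/full checks pass, the wrapper delegates to medpy's `metric.assd`; a minimal sketch of that underlying call on synthetic masks (the arrays and the 1 mm isotropic spacing are illustrative):

import numpy as np
from medpy import metric

test = np.zeros((32, 32, 32), dtype=bool)
test[8:20, 8:20, 8:20] = True
reference = np.zeros((32, 32, 32), dtype=bool)
reference[10:22, 10:22, 10:22] = True

# Same call avg_surface_distance_symmetric() makes internally.
print(metric.assd(test, reference, voxelspacing=(1.0, 1.0, 1.0), connectivity=1))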
Code Example #4
def total_score(pred_path, gt_path, file_names, metric_name):
    print(pred_path)

    gt_to_pred = {k: k for k in [0, 1]}
    list_labels = sorted(gt_to_pred.keys())
    score = dict()
    thre = 400
    score['names'] = []
    score['lesion'] = []

    for name in file_names:
        ground_truth = gt_path.format(name)
        ground_truth = os.path.expanduser(ground_truth)
        image_gt = nibabel.load(ground_truth)
        # get_fdata() replaces the get_data() accessor deprecated and removed in nibabel
        image_gt = nibabel.funcs.as_closest_canonical(image_gt).get_fdata()
        image_gt = image_gt.reshape(image_gt.shape[:3])

        pred = pred_path.format(name)
        pred = os.path.expanduser(pred)
        image_pred = nibabel.load(pred)
        affine = image_pred.affine
        # voxel spacing from the affine diagonal; abs() guards against negative entries
        voxel = [abs(affine[0, 0]), abs(affine[1, 1]), abs(affine[2, 2])]
        image_pred = image_pred.get_fdata()
        image_pred = image_pred.reshape(image_pred.shape[:3])

        score['names'].append(name)
        if metric_name == 'assd':
            score['lesion'].append(metric.assd(image_gt, image_pred, voxel))
        elif metric_name == 'rve':
            score['lesion'].append(metric.ravd(image_gt, image_pred))
        else:
            score['lesion'].append(metric.dc(image_gt, image_pred))

    print('Sample size: {}'.format(len(score['names'])))
    for label in score.keys():
        if label != 'names':
            print('Label: {}, {} mean: {}'.format(label, metric_name, round(np.mean(score[label]), 2)))
            print('Label: {}, {} std: {}'.format(label, metric_name, round(np.std(score[label]), 2)))

    return score
Code Example #5
def get_scores(pred, label, vxlspacing):
    volscores = {}

    volscores['dice'] = metric.dc(pred, label)
    volscores['jaccard'] = metric.binary.jc(pred, label)
    volscores['voe'] = 1. - volscores['jaccard']
    volscores['rvd'] = metric.ravd(label, pred)

    if np.count_nonzero(pred) == 0 or np.count_nonzero(label) == 0:
        volscores['assd'] = 0
        volscores['msd'] = 0
    else:
        # evalsurf = Surface(pred,label,physical_voxel_spacing = vxlspacing,mask_offset = [0.,0.,0.], reference_offset = [0.,0.,0.])
        # volscores['assd'] = evalsurf.get_average_symmetric_surface_distance()
        volscores['assd'] = metric.assd(label, pred, voxelspacing=vxlspacing)
        volscores['msd'] = metric.hd(label, pred, voxelspacing=vxlspacing)

    return volscores
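A minimal sketch of calling `get_scores`, assuming `import numpy as np` and `from medpy import metric` as the function above requires (the arrays and spacing are illustrative):

pred = np.zeros((24, 24, 24), dtype=np.uint8)
pred[6:16, 6:16, 6:16] = 1
label = np.zeros((24, 24, 24), dtype=np.uint8)
label[8:18, 8:18, 8:18] = 1

volscores = get_scores(pred, label, vxlspacing=[1.0, 1.0, 1.5])
for name, value in volscores.items():
    print(name, round(value, 3))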
Code Example #6
def calculate_validation_metrics(probas_pred,
                                 image_gt,
                                 class_labels=None,
                                 num_classes=5):
    classes = np.arange(probas_pred.shape[-1])
    # determine valid classes (those that actually appear in image_gt). Some images may miss some classes
    classes = [c for c in classes if np.sum(image_gt == c) != 0]
    image_pred = probas_pred.argmax(-1)
    assert image_gt.shape == image_pred.shape
    accuracy = np.sum(image_gt == image_pred) / float(image_pred.size)
    class_metrics = {}
    y_true = convert_seg_flat_to_binary_label_indicator_array(
        image_gt.ravel(), num_classes).astype(int)[:, classes]
    y_pred = probas_pred.transpose(3, 0, 1,
                                   2).reshape(num_classes,
                                              -1).transpose(1, 0)[:, classes]
    scores = roc_auc_score(y_true, y_pred, average=None)
    for i, c in enumerate(classes):
        true_positives = metric.obj_tpr(image_gt == c, image_pred == c)
        false_positives = metric.obj_fpr(image_gt == c, image_pred == c)
        dc = metric.dc(image_gt == c, image_pred == c)
        hd = metric.hd(image_gt == c, image_pred == c)
        precision = metric.precision(image_gt == c, image_pred == c)
        recall = metric.recall(image_gt == c, image_pred == c)
        ravd = metric.ravd(image_gt == c, image_pred == c)
        assd = metric.assd(image_gt == c, image_pred == c)
        asd = metric.asd(image_gt == c, image_pred == c)
        label = c
        if class_labels is not None and c in class_labels.keys():
            label = class_labels[c]
        class_metrics[label] = {
            'true_positives': true_positives,
            'false_positives': false_positives,
            'DICE\t\t': dc,
            'Hausdorff dist': hd,
            'precision\t': precision,
            'recall\t\t': recall,
            'rel abs vol diff': ravd,
            'avg surf dist symm': assd,
            'avg surf dist\t': asd,
            'roc_auc\t\t': scores[i]
        }
    return accuracy, class_metrics
Code Example #7
def compute_typical_metrics(seg_gt, seg_pred, labels):
    assert seg_gt.shape == seg_pred.shape
    mask_pred = np.zeros(seg_pred.shape, dtype=bool)
    mask_gt = np.zeros(seg_pred.shape, dtype=bool)

    for l in labels:
        mask_gt[seg_gt == l] = True
        mask_pred[seg_pred == l] = True

    vol_gt = np.sum(mask_gt)
    vol_pred = np.sum(mask_pred)

    try:
        cm = confusion_matrix(
            mask_pred.astype(int).ravel(),
            mask_gt.astype(int).ravel())
        TN = cm[0][0]
        FN = cm[0][1]
        FP = cm[1][0]
        TP = cm[1][1]
        precision = TP / float(TP + FP)
        recall = TP / float(TP + FN)
        fpr = FP / float(FP + TN)
        false_omission_rate = FN / float(FN + TN)
    except Exception:
        precision = np.nan
        recall = np.nan
        fpr = np.nan
        false_omission_rate = np.nan

    try:
        dice = metric.dc(mask_pred, mask_gt)
        if np.sum(mask_gt) == 0:
            dice = np.nan
    except Exception:
        dice = np.nan

    try:
        assd = metric.assd(mask_gt, mask_pred)
    except Exception:
        assd = np.nan

    return precision, recall, fpr, false_omission_rate, dice, assd, vol_gt, vol_pred
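A minimal usage sketch, assuming `import numpy as np`, `from medpy import metric` and `from sklearn.metrics import confusion_matrix` as the function above requires (label value 2 is arbitrary):

seg_gt = np.zeros((32, 32, 32), dtype=np.uint8)
seg_gt[8:20, 8:20, 8:20] = 2
seg_pred = np.zeros((32, 32, 32), dtype=np.uint8)
seg_pred[10:22, 10:22, 10:22] = 2

precision, recall, fpr, fomr, dice, assd_val, vol_gt, vol_pred = \
    compute_typical_metrics(seg_gt, seg_pred, labels=[2])
print(dice, assd_val, vol_gt, vol_pred)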
Code Example #8
def main(FLAGS):
    # set GPU device to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # get the models to evaluate
    ckpts = glob(os.path.join(FLAGS.ckpt_dir, '*.h5'))

    # get data files and sort them
    image_files = os.listdir(FLAGS.test_X_dir)
    anno_files = os.listdir(FLAGS.test_y_dir)
    image_files.sort()
    anno_files.sort()

    for ckpt in ckpts:
        K.clear_session()

        # set model to load
        ckpt_name = os.path.basename(ckpt)

        # create a location to store evaluation metrics
        metrics = np.zeros((len(image_files), FLAGS.classes, 8))
        overall_accuracy = np.zeros((len(image_files), ))

        # create a file writer to store the metrics
        excel_name = os.path.splitext(os.path.basename(ckpt))[0] + '.xlsx'
        writer = pd.ExcelWriter(excel_name)

        model = load_model(ckpt)

        for i in range(len(image_files)):
            # define path to the test data
            test_path = os.path.join(FLAGS.test_X_dir, image_files[i])

            generator = FCN2DDatasetGenerator(
                test_path,
                batch_size=FLAGS.batch_size,
                subset='test',
                normalization=FLAGS.normalization,
                categorical_labels=True,
                num_classes=FLAGS.classes)

            # check if the images and annotations are the correct files
            print(image_files[i], anno_files[i])

            preds = model.predict_generator(generator.generate(),
                                            steps=len(generator))
            stamp = datetime.datetime.fromtimestamp(
                time.time()).strftime('date_%Y_%m_%d_time_%H_%M_%S')
            write_hdf5('fcn_predictions_' + stamp + '.h5', preds)

            pred_file = glob(os.path.join(FLAGS.predictions_temp_dir,
                                          '*.h5'))[0]
            pt_name = image_files[i].split('.')[0]
            new_name_raw = pt_name + ckpt_name
            new_file_raw = os.path.join(FLAGS.predictions_final_dir,
                                        new_name_raw)
            os.rename(pred_file, new_file_raw)

            ref = read_hdf5_multientry(
                os.path.join(FLAGS.test_y_dir, anno_files[i]))
            ref = np.squeeze(np.asarray(ref))

            preds = read_hdf5(new_file_raw)
            preds = np.argmax(preds, axis=-1)

            overall_accuracy[i] = skm.accuracy_score(ref.flatten(),
                                                     preds.flatten())
            for j in range(FLAGS.classes):
                organ_pred = (preds == j).astype(np.int64)
                organ_ref = (ref == j).astype(np.int64)
                if np.sum(organ_pred) == 0 or np.sum(organ_ref) == 0:
                    metrics[i, j, 0] = 0.
                    metrics[i, j, 1] = 0.
                    metrics[i, j, 2] = 1.
                    metrics[i, j, 3] = 0.
                    metrics[i, j, 4] = 0.
                    metrics[i, j, 5] = 0.
                    metrics[i, j, 6] = np.inf
                    metrics[i, j, 7] = np.inf
                else:
                    metrics[i, j, 0] = jaccard_index(organ_ref, organ_pred)
                    metrics[i, j, 1] = dice_similarity_coefficient(
                        organ_ref, organ_pred)
                    metrics[i, j, 2] = relative_volume_difference(
                        organ_ref, organ_pred)
                    metrics[i, j, 3] = precision(organ_ref, organ_pred)
                    metrics[i, j, 4] = recall(organ_ref, organ_pred)
                    metrics[i, j, 5] = matthews_correlation_coefficient(
                        organ_ref, organ_pred)
                    metrics[i, j, 6] = mpm.hd95(organ_pred, organ_ref)
                    metrics[i, j, 7] = mpm.assd(organ_pred, organ_ref)
            print(overall_accuracy[i])
            print(metrics[i])

        for k in range(metrics.shape[-1]):
            data = pd.DataFrame(metrics[:, :, k], columns=['bg', 'met'])
            data.to_excel(writer, sheet_name=str(k))
        acc = pd.DataFrame(overall_accuracy, columns=['acc'])
        acc.to_excel(writer, sheet_name='acc')
        writer.save()
Code Example #9
def wassd(x):
    try:
        val = assd(*x)
    except RuntimeError:
        val = numpy.inf
    return val
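`wassd` unpacks a single tuple argument, which suggests it is meant to be mapped over (result, reference) pairs, e.g. with `map` or `multiprocessing.Pool.map`; a minimal sketch under that assumption, reusing the `numpy` and `assd` imports the snippet above relies on:

a = numpy.zeros((16, 16), dtype=bool)
a[4:10, 4:10] = True
b = numpy.zeros((16, 16), dtype=bool)
b[5:11, 5:11] = True
empty = numpy.zeros((16, 16), dtype=bool)   # an empty mask makes medpy's assd raise RuntimeError

pairs = [(a, b), (a, empty)]
print(list(map(wassd, pairs)))              # the second entry falls back to numpy.inf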
Code Example #10
File: metric.py  Project: NanYoMy/cmmas
def calculate_binary_hd(y_true, y_pred, thres=0.5, spacing=[1, 1, 1]):
    y_true = np.squeeze(y_true)
    y_pred = np.squeeze(y_pred)
    y_true = np.where(y_true > thres, 1, 0)
    y_pred = np.where(y_pred > thres, 1, 0)
    return assd(y_pred, y_true, spacing)
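A minimal usage sketch, assuming `import numpy as np` and `from medpy.metric import assd` as the helper requires; note that despite its name, `calculate_binary_hd` returns the average symmetric surface distance rather than a Hausdorff distance:

y_true = np.zeros((1, 32, 32, 32))
y_true[0, 8:20, 8:20, 8:20] = 1.0
y_pred = np.zeros((1, 32, 32, 32))
y_pred[0, 10:22, 10:22, 10:22] = 0.9   # soft scores are binarised at `thres`
print(calculate_binary_hd(y_true, y_pred, thres=0.5, spacing=[1.0, 1.0, 1.0]))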
Code Example #11
File: evaluator_perregion.py  Project: ihdia/Palmira
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            self.count += len(output["instances"])
            gt_segm = self.annotations[input["file_name"]]["segm_per_region"]
            try:
                _ = output["instances"].pred_masks
            except AttributeError:
                continue
            pred_segm = downsample_points(output)
            doc_ahd = {cat: [] for cat in categories_list}
            doc_hd = {cat: [] for cat in categories_list}
            doc_hd95 = {cat: [] for cat in categories_list}
            doc_iou = {cat: [] for cat in categories_list}
            doc_acc = {cat: [] for cat in categories_list}
            for reg_type in range(len(categories_list)):
                gt, pred = gt_segm[reg_type], pred_segm[reg_type]

                # Both have points
                if len(gt) and len(pred):
                    """"Peformed using medpy"""
                    gt_mask = np.zeros((len(gt), input["height"], input["width"]), dtype=np.int8)
                    for i in range(len(gt)):
                        cv2.fillPoly(gt_mask[i], np.array([gt[i]]).astype(np.int32), 1)
                    pred_mask = np.zeros((len(pred), input["height"], input["width"]), dtype=np.int8)
                    for i in range(len(pred)):
                        cv2.fillPoly(pred_mask[i], np.array([pred[i]]).astype(np.int32), 1)
                    gt_mask = gt_mask.astype(np.uint8)
                    gt_mask = (gt_mask * 255).astype(np.uint8)
                    pred_mask = pred_mask.astype(np.uint8)
                    pred_mask = (pred_mask * 255).astype(np.uint8)

                    def compute_iou_and_accuracy(arrs, edge_mask1):
                        intersection = cv2.bitwise_and(arrs, edge_mask1)
                        union = cv2.bitwise_or(arrs, edge_mask1)
                        intersection_sum = np.sum(intersection)
                        union_sum = np.sum(union)
                        iou = (intersection_sum) / (union_sum)
                        total = np.sum(arrs)
                        correct_predictions = intersection_sum
                        accuracy = correct_predictions / total
                        # print(iou, accuracy)
                        return iou, accuracy

                    iou_hd_dict = np.empty((len(gt), len(pred), 3))
                    # IOU, HD95, AHD
                    for i, each_gt_instance in enumerate(gt_mask):
                        for j, each_pred_instance in enumerate(pred_mask):
                            inst_iou, _ = compute_iou_and_accuracy(gt_mask[i], pred_mask[j])
                            inst_hd95 = hd95(gt_mask[i], pred_mask[j])
                            inst_ahd = assd(gt_mask[i], pred_mask[j])
                            iou_hd_dict[i][j][0] = inst_iou
                            iou_hd_dict[i][j][1] = inst_hd95
                            iou_hd_dict[i][j][2] = inst_ahd
                    corr_matrix = np.empty((len(gt), 4))
                    corr_matrix[:, 0] = np.argmax(iou_hd_dict[:, :, 0], 1)
                    for i, each_gt_instance_metric in enumerate(iou_hd_dict):
                        corr_matrix[i, 1:] = each_gt_instance_metric[corr_matrix[i, 0].astype(int)]

                    gt_mask_copy = gt_mask.copy()
                    gt_mask_copy = np.repeat(gt_mask_copy[:, :, :, np.newaxis], 3, axis=3)

                    def bbox2(img):
                        rows = np.any(img, axis=1)
                        cols = np.any(img, axis=0)
                        rmin, rmax = np.where(rows)[0][[0, -1]]
                        cmin, cmax = np.where(cols)[0][[0, -1]]

                        return rmin, rmax, cmin, cmax

                    for i, each_gt_instance in enumerate(gt_mask):
                        rmin, rmax, cmin, cmax = bbox2(each_gt_instance)
                        pos = (int(cmin), int((rmin + rmax) / 2))
                        gt_mask_copy[i] = cv2.cvtColor(each_gt_instance, cv2.COLOR_GRAY2RGB)
                        cv2.putText(gt_mask_copy[i], f"{i}", pos, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 2)

                    gt_image = gt_mask_copy.sum(axis=0).clip(0, 255).astype(np.uint8)
                    # x = plt.imshow(gt_image)
                    # plt.show()
                    pred_image = pred_mask.sum(axis=0).clip(0, 255)
                    pred_image_copy = pred_image.copy()
                    pred_image_copy = np.repeat(pred_image_copy[:, :, np.newaxis], 3, axis=2).astype(np.uint8)

                    for i, each_pred_instance in enumerate(pred_mask):
                        rmin, rmax, cmin, cmax = bbox2(each_pred_instance)
                        pos = (int(cmin), int((rmin + rmax) / 2))
                        # pred_image_copy[i] = cv2.cvtColor(each_pred_instance, cv2.COLOR_GRAY2RGB)
                        if i in corr_matrix[:, 0].astype(int):
                            if self.write_to_csv:
                                metrics = (corr_matrix[np.where(i == corr_matrix[:, 0]), 1:]).squeeze()
                                if metrics.ndim > 1:
                                    metrics = metrics[np.argmax(metrics[:, 0])]
                                gt_idx = np.where((metrics == corr_matrix[:, 1:]).all(axis=1))[0].item()
                                self.metrics_for_csv.append(
                                    {
                                        "region_name": f"{osp.splitext('_'.join(input['file_name'].split('/')[-3:]))[0]}_gt_{reg_type}_{gt_idx}",
                                        "iou": metrics[0].round(2),
                                        "ahd": metrics[2].round(2),
                                        "hd95": metrics[1].round(2),
                                    }
                                )
                                continue
                            text_metrics = np.array2string(
                                (corr_matrix[np.where(i == corr_matrix[:, 0]), 1:].round(4)).squeeze()
                            )
                            cv2.putText(
                                pred_image_copy,
                                f"{i} - {text_metrics}",
                                pos,
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.7,
                                (255, 0, 0),
                                2,
                            )
                        else:
                            cv2.putText(pred_image_copy, f"{i}", pos, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

                    # x = plt.imshow(pred_image_copy)
                    # plt.show()
                    if not self.write_to_csv:
                        save_path = "final_outputs/comparision/bmrcnn_masks_whew/"
                        import os

                        file_name = os.path.splitext("_".join(input["file_name"].split("/")[-3:]))[0]
                        cv2.imwrite(
                            save_path + f"{file_name}_gt_{reg_type}.jpg", cv2.cvtColor(gt_image, cv2.COLOR_RGB2BGR)
                        )
                        cv2.imwrite(
                            save_path + f"{file_name}_pred_{reg_type}.jpg",
                            cv2.cvtColor(pred_image_copy, cv2.COLOR_RGB2BGR),
                        )
                    # res_iou, res_accuracy = compute_iou_and_accuracy(pred_mask, gt_mask)
                    # res_ahd, res_hd, res_hd95 = assd(pred_mask, gt_mask), hd(pred_mask, gt_mask), hd95(pred_mask, gt_mask)
                    # self.ahd[categories_list[reg_type]].append(res_ahd)
                    # self.hd[categories_list[reg_type]].append(res_hd)
                    # self.hd95[categories_list[reg_type]].append(res_hd95)
                    # self.hd[categories_list[reg_type]].append(hd(pred_mask, gt_mask))
                    # self.iou[categories_list[reg_type]].append(res_iou)
                    # self.acc[categories_list[reg_type]].append(res_accuracy)
                    #
                    # doc_ahd[categories_list[reg_type]].append(res_ahd)
                    # doc_hd[categories_list[reg_type]].append(res_hd)
                    # doc_hd95[categories_list[reg_type]].append(res_hd95)
                    # doc_hd[categories_list[reg_type]].append(hd(pred_mask, gt_mask))
                    # doc_iou[categories_list[reg_type]].append(res_iou)
                    # doc_acc[categories_list[reg_type]].append(res_accuracy)
                # One has points
                # elif len(gt) ^ len(pred):
                #     total_area = 0
                #     for each_pred in pred:
                #         total_area += PolyArea(each_pred[:, 0], each_pred[:, 1])
                #     hd = total_area / 100
                #     self.ahd[categories_list[reg_type]].append(hd)
                #     self.hd[categories_list[reg_type]].append(hd)
                #     self.hd95[categories_list[reg_type]].append(hd)
                # self.iou[categories_list[reg_type]].append(0)
                # self.acc[categories_list[reg_type]].append(0)
                # Both Empty
                # elif len(gt) == 0 and len(pred) != 0:

                else:
                    # self.hd[categories_list[reg_type]].append(0)
                    pass
            # print("Over for doc")
            total_ahd = list()
            for l in doc_ahd.values():
                total_ahd.extend(l)
            doc_ahd = np.mean(total_ahd)

            total_hd = list()
            for l in doc_hd.values():
                total_hd.extend(l)
            doc_hd = np.mean(total_hd)

            total_hd95 = list()
            for l in doc_hd95.values():
                total_hd95.extend(l)
            doc_hd95 = np.mean(total_hd95)

            total_iou = list()
            for l in doc_iou.values():
                total_iou.extend(l)
            doc_iou = np.mean(total_iou)

            total_acc = list()
            for l in doc_acc.values():
                total_acc.extend(l)
            doc_acc = np.mean(total_acc)

            self.doc_wise[input["file_name"]] = {
                "AHD": doc_ahd,
                "IOU": doc_iou,
                "HD": doc_hd,
                "HD95": doc_hd95,
                "ACC": doc_acc,
            }
Code Example #12
def average_surface_distance(data1, data2, voxelspacing=None):
    data1, data2 = to_numpy(data1), to_numpy(data2)
    return assd(data1, data2, voxelspacing)
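A minimal sketch, assuming `from medpy.metric import assd` and that `to_numpy` converts framework tensors to NumPy arrays while passing plain ndarrays through unchanged:

import numpy as np

data1 = np.zeros((24, 24, 24), dtype=bool)
data1[6:16, 6:16, 6:16] = True
data2 = np.zeros((24, 24, 24), dtype=bool)
data2[8:18, 8:18, 8:18] = True
print(average_surface_distance(data1, data2, voxelspacing=(1.0, 1.0, 1.0)))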
Code Example #13
File: evaluator.py  Project: ihdia/Palmira
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            self.count += len(output['instances'])
            gt_segm = self.annotations[input['file_name']]['segm_per_region']
            try:
                _ = output['instances'].pred_masks
            except AttributeError:
                continue
            pred_segm = downsample_points(output)
            doc_ahd = {cat: [] for cat in categories_list}
            doc_hd = {cat: [] for cat in categories_list}
            doc_hd95 = {cat: [] for cat in categories_list}
            doc_iou = {cat: [] for cat in categories_list}
            doc_acc = {cat: [] for cat in categories_list}
            for reg_type in range(len(categories_list)):
                gt, pred = gt_segm[reg_type], pred_segm[reg_type]

                # Both have points
                if len(gt) and len(pred):
                    gt_mask = np.zeros((input['height'], input['width']),
                                       dtype=np.int8)
                    for i in gt:
                        cv2.fillPoly(gt_mask,
                                     np.array([i]).astype(np.int32), 1)
                    pred_mask = np.zeros((input['height'], input['width']),
                                         dtype=np.int8)
                    for i in pred:
                        cv2.fillPoly(pred_mask,
                                     np.array([i]).astype(np.int32), 1)
                    gt_mask = gt_mask.astype(np.uint8)
                    gt_mask = (gt_mask * 255).astype(np.uint8)
                    pred_mask = pred_mask.astype(np.uint8)
                    pred_mask = (pred_mask * 255).astype(np.uint8)

                    def compute_iou_and_accuracy(arrs, edge_mask1):
                        intersection = cv2.bitwise_and(arrs, edge_mask1)
                        union = cv2.bitwise_or(arrs, edge_mask1)
                        intersection_sum = np.sum(intersection)
                        union_sum = np.sum(union)
                        iou = (intersection_sum) / (union_sum)
                        total = np.sum(arrs)
                        correct_predictions = intersection_sum
                        accuracy = correct_predictions / total
                        # print(iou, accuracy)
                        return iou, accuracy

                    res_iou, res_accuracy = compute_iou_and_accuracy(
                        pred_mask, gt_mask)
                    res_ahd, res_hd, res_hd95 = assd(pred_mask, gt_mask), hd(
                        pred_mask, gt_mask), hd95(pred_mask, gt_mask)
                    self.ahd[categories_list[reg_type]].append(res_ahd)
                    self.hd[categories_list[reg_type]].append(res_hd)
                    self.hd95[categories_list[reg_type]].append(res_hd95)
                    self.hd[categories_list[reg_type]].append(
                        hd(pred_mask, gt_mask))
                    self.iou[categories_list[reg_type]].append(res_iou)
                    self.acc[categories_list[reg_type]].append(res_accuracy)

                    doc_ahd[categories_list[reg_type]].append(res_ahd)
                    doc_hd[categories_list[reg_type]].append(res_hd)
                    doc_hd95[categories_list[reg_type]].append(res_hd95)
                    doc_hd[categories_list[reg_type]].append(
                        hd(pred_mask, gt_mask))
                    doc_iou[categories_list[reg_type]].append(res_iou)
                    doc_acc[categories_list[reg_type]].append(res_accuracy)
                # One has points
                # elif len(gt) ^ len(pred):
                #     total_area = 0
                #     for each_pred in pred:
                #         total_area += PolyArea(each_pred[:, 0], each_pred[:, 1])
                #     hd = total_area / 100
                #     self.ahd[categories_list[reg_type]].append(hd)
                #     self.hd[categories_list[reg_type]].append(hd)
                #     self.hd95[categories_list[reg_type]].append(hd)
                # self.iou[categories_list[reg_type]].append(0)
                # self.acc[categories_list[reg_type]].append(0)
                # Both Empty
                # elif len(gt) == 0 and len(pred) != 0:

                else:
                    # self.hd[categories_list[reg_type]].append(0)
                    pass
            # print("Over for doc")
            total_ahd = list()
            for l in doc_ahd.values():
                total_ahd.extend(l)
            doc_ahd = np.mean(total_ahd)

            total_hd = list()
            for l in doc_hd.values():
                total_hd.extend(l)
            doc_hd = np.mean(total_hd)

            total_hd95 = list()
            for l in doc_hd95.values():
                total_hd95.extend(l)
            doc_hd95 = np.mean(total_hd95)

            total_iou = list()
            for l in doc_iou.values():
                total_iou.extend(l)
            doc_iou = np.mean(total_iou)

            total_acc = list()
            for l in doc_acc.values():
                total_acc.extend(l)
            doc_acc = np.mean(total_acc)

            self.doc_wise[input['file_name']] = {
                "AHD": doc_ahd,
                "IOU": doc_iou,
                "HD": doc_hd,
                "HD95": doc_hd95,
                "ACC": doc_acc
            }
Code Example #14
def main(model_path, exp_config, do_plots=False):

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make and restore vagan model
    segmenter_model = segmenter(exp_config=exp_config, data=data, fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    # Run predictions in an endless loop
    dice_list = []
    assd_list = []
    hd_list = []

    for ii, batch in enumerate(data.test.iterate_batches(1)):

        if ii % 100 == 0:
            logging.info("Progress: %d" % ii)

        x, y = batch

        y_ = segmenter_model.predict(x)[0]

        per_lbl_dice = []
        per_lbl_assd = []
        per_lbl_hd = []
        per_pixel_preds = []
        per_pixel_gts = []

        if do_plots and not sys_config.running_on_gpu_host:
            fig = plt.figure()
            fig.add_subplot(131)
            plt.imshow(np.squeeze(x), cmap='gray')
            fig.add_subplot(132)
            plt.imshow(np.squeeze(y_))
            fig.add_subplot(133)
            plt.imshow(np.squeeze(y))
            plt.show()

        for lbl in range(exp_config.nlabels):

            binary_pred = (y_ == lbl) * 1
            binary_gt = (y == lbl) * 1

            if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0:
                per_lbl_dice.append(1)
                per_lbl_assd.append(0)
                per_lbl_hd.append(0)
            elif np.sum(binary_pred) > 0 and np.sum(binary_gt) == 0 or np.sum(binary_pred) == 0 and np.sum(binary_gt) > 0:
                logging.warning('Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.')
                per_lbl_dice.append(0)
                per_lbl_assd.append(1)
                per_lbl_hd.append(1)
            else:
                per_lbl_dice.append(dc(binary_pred, binary_gt))
                per_lbl_assd.append(assd(binary_pred, binary_gt))
                per_lbl_hd.append(hd(binary_pred, binary_gt))

        dice_list.append(per_lbl_dice)
        assd_list.append(per_lbl_assd)
        hd_list.append(per_lbl_hd)

        per_pixel_preds.append(y_.flatten())
        per_pixel_gts.append(y.flatten())

    dice_arr = np.asarray(dice_list)
    assd_arr = np.asarray(assd_list)
    hd_arr = np.asarray(hd_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)
    mean_per_lbl_assd = assd_arr.mean(axis=0)
    mean_per_lbl_hd = hd_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(structures_dict)
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    logging.info('foreground mean: %f' % (np.mean(mean_per_lbl_dice[1:])))
    logging.info('ASSD')
    logging.info(structures_dict)
    logging.info(mean_per_lbl_assd)
    logging.info(np.mean(mean_per_lbl_assd))
    logging.info('HD')
    logging.info(structures_dict)
    logging.info(mean_per_lbl_hd)
    logging.info(np.mean(mean_per_lbl_hd))
Code Example #15
def main(FLAGS):
    # set GPU device to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # define the experiments
    encoders = FLAGS.encoders.split(',')
    losses = FLAGS.losses.split(',')
    alpha = FLAGS.loss_param1.split(',')
    experiments = [encoders, losses, alpha]
    print(experiments)

    # get data files and sort them
    image_files = os.listdir(FLAGS.test_X_dir)
    anno_files = os.listdir(FLAGS.test_y_dir)
    image_files.sort()
    anno_files.sort()

    for experiment in itertools.product(*experiments):
        # switch to activate training session
        do_eval = True

        # load the base configurations
        if experiment[0] == 'UNet':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet2d.json'))
        elif experiment[0] == 'UNet3D':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet3d.json'))
        elif experiment[0] == 'VGG16' or experiment[0] == 'VGG19':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'vgg16_unet.json'))
        else:
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'xception_unet.json'))

        # set the path to the model (checkpoints are fine for this)
        ckpt_name = 'encoder_{}_loss_{}_alpha_{}_beta_{}_ckpt.h5'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        configs['paths']['load_model'] = os.path.join(FLAGS.ckpt_dir,
                                                      ckpt_name)

        # switch to inference
        configs['config_file']['type_signal'] = 'Inference'

        # perform some preprocessing
        configs['preprocessing']['categorical_switch'] = 'True'
        configs['preprocessing']['minimum_image_intensity'] = '0.0'
        configs['preprocessing']['maximum_image_intensity'] = '2048.0'
        configs['preprocessing']['normalization_type'] = '{}'.format(
            FLAGS.normalization)

        # set some other configurations
        configs['training_configurations']['batch_size'] = '{}'.format(
            FLAGS.batch_size)
        configs['config_file']['input_shape'] = '({}, {}, {})'.format(
            FLAGS.height, FLAGS.width, FLAGS.channels)

        # ensure xentropy/jaccard/focal only used once per encoder
        if experiment[1] == 'sparse_categorical_crossentropy'\
                or experiment[1] == 'categorical_crossentropy'\
                or experiment[1] == 'jaccard'\
                or experiment[1] == 'focal':
            if experiment[1] == 'focal':
                configs['loss_function']['parameter1'] = '0.75'
                configs['loss_function']['parameter2'] = '2.0'
            if experiment[1] == 'jaccard':
                configs['loss_function']['parameter1'] = '100.0'
            if experiment[2] == '0.3':
                configs['loss_function']['loss'] = experiment[1]
            else:
                do_eval = False
        elif experiment[1] == 'tversky':
            configs['loss_function']['loss'] = experiment[1]
            configs['loss_function']['parameter1'] = experiment[2]
            configs['loss_function']['parameter2'] = str(
                1. - literal_eval(experiment[2]))
        else:
            do_eval = False

        # create a location to store evaluation metrics
        metrics = np.zeros((len(image_files), FLAGS.classes, 8))
        overall_accuracy = np.zeros((len(image_files), ))

        # create a file writer to store the metrics
        excel_name = '{}_{}_{}_{}_metrics.xlsx'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        writer = pd.ExcelWriter(excel_name)

        for i in range(len(image_files)):
            K.clear_session()
            # define path to the test data
            configs['paths']['test_X'] = os.path.join(FLAGS.test_X_dir,
                                                      image_files[i])

            if do_eval is True:
                configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(
                    configs)

                if any(warnings_lvl1):
                    with open('errors.txt', 'a') as f:
                        for warning in warnings_lvl1:
                            f.write("%s\n" % warning)
                        f.close()
                        print('Level 1 warnings encountered.')
                        print(
                            "The following level 1 warnings were identified and corrected based on engine defaults:"
                        )
                        for warning in warnings_lvl1:
                            print(warning)

                if any(errors_lvl1):
                    print('Level 1 errors encountered.')
                    print(
                        "Please fix the level 1 errors below before continuing:"
                    )
                    for error in errors_lvl1:
                        print(error)
                else:
                    configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(
                        configs_lvl1)

                    if any(warnings_lvl2):
                        print('Level 2 warnings encountered.')
                        print(
                            "The following level 2 warnings were identified and corrected based on engine defaults:"
                        )
                        for warning in warnings_lvl2:
                            print(warning)

                    if any(errors_lvl2):
                        print('Level 2 errors encountered.')
                        print(
                            "Please fix the level 2 errors below before continuing:"
                        )
                        for error in errors_lvl2:
                            print(error)
                    else:
                        engine = Dlae(configs)
                        engine.run()
                        if any(engine.errors):
                            print('Level 3 errors encountered.')
                            print(
                                "Please fix the level 3 errors below before continuing:"
                            )
                            for error in engine.errors:
                                print(error)

                # check if the images and annotations are the correct files
                print(image_files[i], anno_files[i])

                pred_file = glob(
                    os.path.join(FLAGS.predictions_temp_dir, '*.h5'))[0]
                pt_name = image_files[i].split('.')[0]
                new_name_raw = pt_name + '_{}_{}_{}_{}_raw.h5'.format(
                    experiment[0], experiment[1], experiment[2],
                    1. - literal_eval(experiment[2]))
                new_file_raw = os.path.join(FLAGS.predictions_final_dir,
                                            new_name_raw)
                os.rename(pred_file, new_file_raw)

                ref = read_hdf5_multientry(
                    os.path.join(FLAGS.test_y_dir, anno_files[i]))
                ref = np.squeeze(np.asarray(ref))

                preds = read_hdf5(new_file_raw)
                if experiment[0] == 'UNet3D':
                    ref = np.transpose(ref, (1, 2, 0))

                    # stich the image back together first
                    sw = SlidingWindow(ref, [96, 96, 40], [128, 128, 48])
                    preds = sw.stitch_patches(preds, sw.window_corner_coords,
                                              [128, 128, 48], sw.img_shape,
                                              FLAGS.classes)
                    preds = np.argmax(preds, axis=-1)
                else:
                    preds = np.argmax(preds, axis=-1)

                overall_accuracy[i] = skm.accuracy_score(
                    ref.flatten(), preds.flatten())
                for j in range(FLAGS.classes):
                    organ_pred = (preds == j).astype(np.int64)
                    organ_ref = (ref == j).astype(np.int64)
                    if np.sum(organ_pred) == 0 or np.sum(organ_ref) == 0:
                        metrics[i, j, 0] = 0.
                        metrics[i, j, 1] = 0.
                        metrics[i, j, 2] = 1.
                        metrics[i, j, 3] = 0.
                        metrics[i, j, 4] = 0.
                        metrics[i, j, 5] = 0.
                        metrics[i, j, 6] = np.inf
                        metrics[i, j, 7] = np.inf
                    else:
                        metrics[i, j, 0] = jaccard_index(organ_ref, organ_pred)
                        metrics[i, j, 1] = dice_similarity_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 2] = relative_volume_difference(
                            organ_ref, organ_pred)
                        metrics[i, j, 3] = precision(organ_ref, organ_pred)
                        metrics[i, j, 4] = recall(organ_ref, organ_pred)
                        metrics[i, j, 5] = matthews_correlation_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 6] = mpm.hd95(organ_pred, organ_ref)
                        metrics[i, j, 7] = mpm.assd(organ_pred, organ_ref)
                print(overall_accuracy[i])
                print(metrics[i])

            else:
                pass

        if do_eval is True:
            for k in range(metrics.shape[-1]):
                data = pd.DataFrame(
                    metrics[:, :, k],
                    columns=['bg', 'pros', 'eus', 'sv', 'rect', 'blad'])
                data.to_excel(writer, sheet_name=str(k))
            acc = pd.DataFrame(overall_accuracy, columns=['acc'])
            acc.to_excel(writer, sheet_name='acc')
            writer.save()
Code Example #16
File: main.py  Project: joagzb/Melano-malo

# =============================================================================
#   Mean surface distance
# =============================================================================

distanceASD = mdm.asd(cHull_binary, cHull_Mbinary, connectivity=1)

print('Mean surface distance:', distanceASD)



# =============================================================================
#   Mean symmetric surface distance
# =============================================================================
distanceASSD = mdm.assd(cHull_binary, cHull_Mbinary, connectivity=1)

print('Mean symmetric surface distance:', distanceASSD)


# =============================================================================
#   Mean surface distance between objects
# =============================================================================

distanceObjASD = mdm.obj_asd(cHull_binary, cHull_Mbinary, connectivity=1)

print('Mean surface distance between objects:', distanceObjASD)



Code Example #17
 
for j_c in range(n_class):
    dice_c[i_c, j_c] = dk_seg.dice(batch_y[0, :, :, :, j_c] == 1, y_tr_pr_c == j_c)

y_t_c = np.argmax(batch_y[0, :, :, :, :], axis=-1)

#dk_aux.save_pred_thumbs(batch_x[0,:,:,:,0], y_t_c, y_tr_pr_c, False, i_c, i_eval, images_dir)

'''if i_eval==0:
    save_pred_mhds(batch_x[0,:,:,:,0], y_t_c, y_tr_pr_c, False, i_c, i_eval)
else:
    save_pred_mhds(None, None, y_tr_pr_c, False, i_c, i_eval)'''

dice_c[i_c, n_class] = hd95(y_t_c, y_tr_pr_c)
dice_c[i_c, n_class + 1] = asd(y_t_c, y_tr_pr_c)
dice_c[i_c, n_class + 2] = assd(y_t_c, y_tr_pr_c)

y_tr_pr_soft = y_tr_pr_sum[:, :, :, 1] / (y_tr_pr_cnt + 1e-10)
#dk_aux.save_pred_soft_thumbs(batch_x[0,:,:,:,0], y_t_c, y_tr_pr_c, y_tr_pr_soft, False, i_c, i_eval, images_dir)

error_mask = dk_aux.seg_2_anulus(y_t_c, radius=2.0)

plot_save_path = None
ECE, MCE, ECE_curve = dk_aux.estimate_ECE_and_MCE(y_t_c, y_tr_pr_soft, plot_save_path=plot_save_path)
dice_c[i_c, n_class + 3] = ECE
dice_c[i_c, n_class + 4] = MCE

plot_save_path = None
ECE, MCE, ECE_curve = dk_aux.estimate_ECE_and_MCE_masked(y_t_c, y_tr_pr_soft, error_mask, plot_save_path=plot_save_path)
dice_c[i_c, n_class + 5] = ECE
dice_c[i_c, n_class + 6] = MCE
Code Example #18
def main(input_folder,
         output_folder,
         model_path,
         exp_config,
         do_postprocessing=False,
         gt_exists=True):

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make and restore vagan model
    segmenter_model = segmenter(
        exp_config=exp_config, data=data,
        fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    total_time = 0
    total_volumes = 0

    dice_list = []
    assd_list = []
    hd_list = []

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            if not int(patient_id) % 5 == 0:
                continue

            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img, img_affine, img_header = utils.load_nii(file)
                img = utils.normalise_image(img)
                zooms = img_header.get_zooms()

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask, mask_affine, mask_header = utils.load_nii(file_mask)

                start_time = time.time()

                if exp_config.dimensionality_mode == '2D':

                    pixel_size = (img_header.structarr['pixdim'][1],
                                  img_header.structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    nx, ny = exp_config.image_size

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape

                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx,
                                                           y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c +
                                              x, :] = slice_rescaled[:,
                                                                     y_s:y_s +
                                                                     ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c +
                                              y] = slice_rescaled[x_s:x_s +
                                                                  nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c +
                                              y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (1, 1, 1, 1)))
                        mask_out, softmax = segmenter_model.predict(
                            network_input)

                        prediction_cropped = np.squeeze(softmax[0, ...])

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros(
                            (x, y, exp_config.nlabels))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s +
                                              ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s +
                                                  ny, :] = prediction_cropped[
                                                      x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[
                                    x_s:x_s +
                                    nx, :, :] = prediction_cropped[:, y_c:y_c +
                                                                   y, :]
                            else:
                                slice_predictions[:, :, :] = prediction_cropped[
                                    x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1],
                                 exp_config.nlabels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:  # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                slice_predictions, (1.0 / scale_vector[0],
                                                    1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        # import matplotlib.pyplot as plt
                        # fig = plt.Figure()
                        # for ii in range(3):
                        #     plt.subplot(1, 3, ii + 1)
                        #     plt.imshow(np.squeeze(prediction))
                        # plt.show()

                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.dimensionality_mode == '3D':

                    nx, ny, nz = exp_config.image_size

                    pixel_size = (img_header.structarr['pixdim'][1],
                                  img_header.structarr['pixdim'][2],
                                  img_header.structarr['pixdim'][3])

                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1],
                                    pixel_size[2] /
                                    exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    nz_curr = vol_scaled.shape[2]
                    stack_from = (nz_max - nz_curr) // 2
                    stack_counter = stack_from

                    x, y, z = vol_scaled.shape

                    x_s = (x - nx) // 2
                    y_s = (y - ny) // 2
                    x_c = (nx - x) // 2
                    y_c = (ny - y) // 2

                    for zz in range(nz_curr):

                        slice_rescaled = vol_scaled[:, :, zz]

                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx,
                                                           y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c +
                                              x, :] = slice_rescaled[:,
                                                                     y_s:y_s +
                                                                     ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c +
                                              y] = slice_rescaled[x_s:x_s +
                                                                  nx, :]

                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c +
                                              y] = slice_rescaled[:, :]

                        slice_vol[:, :, stack_counter] = slice_cropped
                        stack_counter += 1

                    stack_to = stack_counter

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))
                    start_time = time.time()
                    mask_out, softmax = segmenter_model.predict(network_input)
                    logging.info('Classified 3D: %f secs' %
                                 (time.time() - start_time))

                    prediction_nzs = mask_out[0, :, :, stack_from:
                                              stack_to]  # non-zero-slices

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES
                    prediction_scaled = np.zeros(
                        vol_scaled.shape)  # last dim is for logits classes

                    # insert cropped region into original image again
                    if x > nx and y > ny:
                        prediction_scaled[x_s:x_s + nx,
                                          y_s:y_s + ny, :] = prediction_nzs
                    else:
                        if x <= nx and y > ny:
                            prediction_scaled[:, y_s:y_s +
                                              ny, :] = prediction_nzs[x_c:x_c +
                                                                      x, :, :]
                        elif x > nx and y <= ny:
                            prediction_scaled[
                                x_s:x_s +
                                nx, :, :] = prediction_nzs[:, y_c:y_c + y, :]
                        else:
                            prediction_scaled[:, :, :] = prediction_nzs[
                                x_c:x_c + x, y_c:y_c + y, :]

                    logging.info('Prediction_scaled mean %f' %
                                 (np.mean(prediction_scaled)))

                    prediction = transform.resize(
                        prediction_scaled,
                        (mask.shape[0], mask.shape[1], mask.shape[2], 1),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' %
                             elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # Save prediced mask
                out_file_name = os.path.join(
                    output_folder, 'prediction',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_affine
                    out_header = mask_header
                else:
                    out_affine = img_affine
                    out_header = img_header

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine,
                               out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(
                    output_folder, 'image',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img, out_affine, out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(
                        output_folder, 'ground_truth',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(
                        np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask,
                                                 dtype=np.uint8)
                    diff_file_name = os.path.join(
                        output_folder, 'difference',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine,
                                   out_header)

                # calculate metrics
                y_ = prediction_arr
                y = mask

                per_lbl_dice = []
                per_lbl_assd = []
                per_lbl_hd = []

                for lbl in [3, 1, 2]:  #range(exp_config.nlabels):

                    binary_pred = (y_ == lbl) * 1
                    binary_gt = (y == lbl) * 1

                    if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0:
                        per_lbl_dice.append(1)
                        per_lbl_assd.append(0)
                        per_lbl_hd.append(0)
                    elif np.sum(binary_pred) > 0 and np.sum(
                            binary_gt) == 0 or np.sum(
                                binary_pred) == 0 and np.sum(binary_gt) > 0:
                        logging.warning(
                            'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
                        )
                        per_lbl_dice.append(0)
                        per_lbl_assd.append(1)
                        per_lbl_hd.append(1)
                    else:
                        per_lbl_dice.append(dc(binary_pred, binary_gt))
                        per_lbl_assd.append(
                            assd(binary_pred, binary_gt, voxelspacing=zooms))
                        per_lbl_hd.append(
                            hd(binary_pred, binary_gt, voxelspacing=zooms))

                dice_list.append(per_lbl_dice)
                assd_list.append(per_lbl_assd)
                hd_list.append(per_lbl_hd)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    dice_arr = np.asarray(dice_list)
    assd_arr = np.asarray(assd_list)
    hd_arr = np.asarray(hd_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)
    mean_per_lbl_assd = assd_arr.mean(axis=0)
    mean_per_lbl_hd = hd_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    logging.info('ASSD')
    logging.info(mean_per_lbl_assd)
    logging.info(np.mean(mean_per_lbl_assd))
    logging.info('HD')
    logging.info(mean_per_lbl_hd)
    logging.info(np.mean(mean_per_lbl_hd))