Code example #1
import numpy as np
from medpy import metric
from tensorflow.keras.utils import to_categorical  # or keras.utils, depending on the project


def calculate_metric_percase(pred, gt, num_classes):
    """Compute Dice, Jaccard, HD95, and ASD for binary or multi-class segmentation."""
    if num_classes is None:
        num_classes = len(np.unique(gt))  # note: gt is label-encoded, not one-hot
    print('np.unique(gt):', np.unique(gt))
    if num_classes == 2:
        dice = metric.binary.dc(pred, gt)
        jc = metric.binary.jc(pred, gt)
        hd = metric.binary.hd95(pred, gt)
        asd = metric.binary.asd(pred, gt)
    elif num_classes > 2:
        gt_onehot = to_categorical(gt, num_classes)
        pred_onehot = to_categorical(pred, num_classes)
        dice, jc, hd, asd = [], [], [], []
        for k in range(num_classes):
            pred_k = pred_onehot[..., k]
            gt_k = gt_onehot[..., k]
            # note: hd95/asd raise if class k is absent from pred or gt
            dice.append(metric.dc(result=pred_k, reference=gt_k))
            jc.append(metric.jc(result=pred_k, reference=gt_k))
            hd.append(metric.hd95(result=pred_k, reference=gt_k))
            asd.append(metric.asd(result=pred_k, reference=gt_k))
    else:
        raise ValueError("pred and gt must not be one-hot encoded")
    return dice, jc, hd, asd
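
A minimal usage sketch for the function above (hypothetical inputs, assuming medpy and a Keras-style to_categorical are installed); note that both pred and gt are label-encoded volumes, not one-hot:

import numpy as np

# synthetic 3-class volumes; labels are plain integers, not one-hot
gt = np.zeros((8, 8, 8), dtype=np.int64)
gt[2:5, 2:5, 2:5] = 1
gt[5:7, 5:7, 5:7] = 2
pred = gt.copy()
pred[2, 2, 2] = 0  # introduce a small disagreement

dice, jc, hd, asd = calculate_metric_percase(pred, gt, num_classes=3)
print(dice)  # one Dice score per class, background included
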
Code example #2
File: metrics.py Project: mikiwang820/MIS
import torch.nn.functional as F
from medpy.metric.binary import hd95


def housdorff_distance_95(logits, label):  # sic: name as spelled in the source file
    # medpy's hd95 expects binary masks, so the soft probabilities are
    # discretized with an argmax over the class dimension first
    pred = F.softmax(logits, dim=1).argmax(dim=1)
    n_classes = logits.shape[1]

    result = []
    for cls in range(1, n_classes):  # class 0 (background) is skipped
        # label is assumed to be a tensor on the same device as logits
        result.append(hd95((pred == cls).cpu().numpy(),
                           (label == cls).cpu().numpy()))

    return result
Code example #3
import numpy as np
from medpy import metric


def hausdorff(prediction: np.ndarray, reference: np.ndarray) -> float:
    try:
        return metric.hd95(prediction, reference)
    except Exception as e:
        # hd95 raises when either mask is empty; fall back to a fixed penalty
        print("Error: ", e)
        print(
            f"Prediction does not contain the same label as gt. "
            f"Pred labels {np.unique(prediction)} GT labels {np.unique(reference)}"
        )
        return 373
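
A short check of the fallback path (synthetic masks, assuming medpy is installed): an empty prediction makes metric.hd95 raise, so the wrapper returns its fixed penalty of 373.

import numpy as np

gt = np.zeros((16, 16), dtype=np.uint8)
gt[4:12, 4:12] = 1

print(hausdorff(gt.copy(), gt))          # 0.0: identical masks
print(hausdorff(np.zeros_like(gt), gt))  # 373: empty prediction raises inside hd95
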
Code example #4
File: metrics.py Project: MarEe0/FeTA-Spatial
from medpy import metric

# ConfusionMatrix is defined elsewhere in this project's metrics module


def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None,
                          nan_for_nonexisting=True, voxel_spacing=None,
                          connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    # HD95 is undefined when either mask is empty or fills the whole volume
    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.hd95(test, reference, voxel_spacing, connectivity)
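
ConfusionMatrix above is project-internal. As a rough, hypothetical stand-in (not the project's class) that exposes only the interface this function relies on:

import numpy as np

class ConfusionMatrix:
    # minimal sketch: stores the two masks and reports empty/full states
    def __init__(self, test=None, reference=None):
        self.test = np.asarray(test).astype(bool)
        self.reference = np.asarray(reference).astype(bool)

    def get_existence(self):
        test_empty = not self.test.any()
        test_full = bool(self.test.all())
        reference_empty = not self.reference.any()
        reference_full = bool(self.reference.all())
        return test_empty, test_full, reference_empty, reference_full
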
Code example #5
File: Task082_BraTS_2020.py Project: zzsnow/nnUNet
import numpy as np
from medpy.metric.binary import hd95


def compute_BraTS_HD95(ref, pred):
    """
    ref and pred are binary integer numpy.ndarrays;
    spacing is assumed to be (1, 1, 1)
    :param ref:
    :param pred:
    :return:
    """
    num_ref = np.sum(ref)
    num_pred = np.sum(pred)

    if num_ref == 0:
        if num_pred == 0:
            return 0
        else:
            return 373.12866  # BraTS convention: fixed penalty for a spurious prediction
    elif num_pred == 0 and num_ref != 0:
        return 373.12866  # same penalty for a missed structure
    else:
        return hd95(pred, ref, (1, 1, 1))
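
A quick exercise of the three branches (synthetic data, assuming medpy's hd95 is importable); 373.12866 is the fixed BraTS penalty when exactly one of the masks is empty:

import numpy as np

ref = np.zeros((10, 10, 10), dtype=np.uint8)
ref[3:7, 3:7, 3:7] = 1

print(compute_BraTS_HD95(np.zeros_like(ref), np.zeros_like(ref)))  # 0: both empty
print(compute_BraTS_HD95(ref, np.zeros_like(ref)))                 # 373.12866: missed structure
print(compute_BraTS_HD95(ref, ref))                                # 0.0: perfect overlap
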
Code example #6
batch_x = X_test[i_c:i_c + 1, :, :, :, :].copy()
batch_y = Y_test[i_c:i_c + 1, :, :, :, :].copy()

# per-class Dice against the one-hot ground truth
for j_c in range(n_class):
    dice_c[i_c, j_c] = dk_seg.dice(batch_y[0, :, :, :, j_c] == 1, y_tr_pr_c == j_c)

y_t_c = np.argmax(batch_y[0, :, :, :, :], axis=-1)

# dk_aux.save_pred_thumbs(batch_x[0, :, :, :, 0], y_t_c, y_tr_pr_c, False, i_c, i_eval, images_dir)

# if i_eval == 0:
#     save_pred_mhds(batch_x[0, :, :, :, 0], y_t_c, y_tr_pr_c, False, i_c, i_eval)
# else:
#     save_pred_mhds(None, None, y_tr_pr_c, False, i_c, i_eval)

# surface distances are stored after the per-class Dice columns
dice_c[i_c, n_class] = hd95(y_t_c, y_tr_pr_c)
dice_c[i_c, n_class + 1] = asd(y_t_c, y_tr_pr_c)
dice_c[i_c, n_class + 2] = assd(y_t_c, y_tr_pr_c)

y_tr_pr_soft = y_tr_pr_sum[:, :, :, 1] / (y_tr_pr_cnt + 1e-10)
# dk_aux.save_pred_soft_thumbs(batch_x[0, :, :, :, 0], y_t_c, y_tr_pr_c, y_tr_pr_soft, False, i_c, i_eval, images_dir)

error_mask = dk_aux.seg_2_anulus(y_t_c, radius=2.0)

# calibration metrics (expected / maximum calibration error)
plot_save_path = None
ECE, MCE, ECE_curve = dk_aux.estimate_ECE_and_MCE(y_t_c, y_tr_pr_soft, plot_save_path=plot_save_path)
dice_c[i_c, n_class + 3] = ECE
dice_c[i_c, n_class + 4] = MCE

plot_save_path = None
ECE, MCE, ECE_curve = dk_aux.estimate_ECE_and_MCE_masked(y_t_c, y_tr_pr_soft, error_mask, plot_save_path=plot_save_path)
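
The fragment assumes hd95, asd, and assd come from medpy.metric.binary; a self-contained sketch of those three surface metrics on synthetic masks:

import numpy as np
from medpy.metric.binary import hd95, asd, assd

gt = np.zeros((32, 32), dtype=np.uint8)
gt[8:24, 8:24] = 1
pred = np.zeros_like(gt)
pred[10:26, 10:26] = 1  # same square shifted by two pixels

print(hd95(pred, gt))  # 95th percentile of the symmetric surface distances
print(asd(pred, gt))   # average surface distance, pred surface -> gt surface
print(assd(pred, gt))  # average symmetric surface distance
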
Code example #7
def main(FLAGS):
    # set GPU device to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # get the models to evaluate
    ckpts = glob(os.path.join(FLAGS.ckpt_dir, '*.h5'))

    # get data files and sort them
    image_files = os.listdir(FLAGS.test_X_dir)
    anno_files = os.listdir(FLAGS.test_y_dir)
    image_files.sort()
    anno_files.sort()

    for ckpt in ckpts:
        K.clear_session()

        # set model to load
        ckpt_name = os.path.basename(ckpt)

        # create a location to store evaluation metrics
        metrics = np.zeros((len(image_files), FLAGS.classes, 8))
        overall_accuracy = np.zeros((len(image_files), ))

        # create a file writer to store the metrics
        excel_name = os.path.splitext(os.path.basename(ckpt))[0] + '.xlsx'
        writer = pd.ExcelWriter(excel_name)

        model = load_model(ckpt)

        for i in range(len(image_files)):
            # define path to the test data
            test_path = os.path.join(FLAGS.test_X_dir, image_files[i])

            generator = FCN2DDatasetGenerator(
                test_path,
                batch_size=FLAGS.batch_size,
                subset='test',
                normalization=FLAGS.normalization,
                categorical_labels=True,
                num_classes=FLAGS.classes)

            # check if the images and annotations are the correct files
            print(image_files[i], anno_files[i])

            preds = model.predict_generator(generator.generate(),
                                            steps=len(generator))
            stamp = datetime.datetime.fromtimestamp(
                time.time()).strftime('date_%Y_%m_%d_time_%H_%M_%S')
            write_hdf5('fcn_predictions_' + stamp + '.h5', preds)

            pred_file = glob(os.path.join(FLAGS.predictions_temp_dir,
                                          '*.h5'))[0]
            pt_name = image_files[i].split('.')[0]
            new_name_raw = pt_name + ckpt_name
            new_file_raw = os.path.join(FLAGS.predictions_final_dir,
                                        new_name_raw)
            os.rename(pred_file, new_file_raw)

            ref = read_hdf5_multientry(
                os.path.join(FLAGS.test_y_dir, anno_files[i]))
            ref = np.squeeze(np.asarray(ref))

            preds = read_hdf5(new_file_raw)
            preds = np.argmax(preds, axis=-1)

            overall_accuracy[i] = skm.accuracy_score(ref.flatten(),
                                                     preds.flatten())
            for j in range(FLAGS.classes):
                organ_pred = (preds == j).astype(np.int64)
                organ_ref = (ref == j).astype(np.int64)
                if np.sum(organ_pred) == 0 or np.sum(organ_ref) == 0:
                    # overlap metrics default to their worst case and the
                    # surface distances to inf when either mask is empty
                    metrics[i, j, 0] = 0.
                    metrics[i, j, 1] = 0.
                    metrics[i, j, 2] = 1.
                    metrics[i, j, 3] = 0.
                    metrics[i, j, 4] = 0.
                    metrics[i, j, 5] = 0.
                    metrics[i, j, 6] = np.inf
                    metrics[i, j, 7] = np.inf
                else:
                    metrics[i, j, 0] = jaccard_index(organ_ref, organ_pred)
                    metrics[i, j, 1] = dice_similarity_coefficient(
                        organ_ref, organ_pred)
                    metrics[i, j, 2] = relative_volume_difference(
                        organ_ref, organ_pred)
                    metrics[i, j, 3] = precision(organ_ref, organ_pred)
                    metrics[i, j, 4] = recall(organ_ref, organ_pred)
                    metrics[i, j, 5] = matthews_correlation_coefficient(
                        organ_ref, organ_pred)
                    metrics[i, j, 6] = mpm.hd95(organ_pred, organ_ref)
                    metrics[i, j, 7] = mpm.assd(organ_pred, organ_ref)
            print(overall_accuracy[i])
            print(metrics[i])

        for k in range(metrics.shape[-1]):
            data = pd.DataFrame(metrics[:, :, k], columns=['bg', 'met'])
            data.to_excel(writer, sheet_name=str(k))
        acc = pd.DataFrame(overall_accuracy, columns=['acc'])
        acc.to_excel(writer, sheet_name='acc')
        writer.save()
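
Stripped of the file I/O, the per-class metric loop in main() reduces to binarizing the argmax output one label at a time and guarding the empty-mask case; a compact sketch (hypothetical helper; mpm is assumed to alias medpy's binary metrics as in the code above):

import numpy as np
from medpy.metric import binary as mpm

def per_class_distances(preds, ref, n_classes):
    # HD95 and ASSD per class, inf when either binary mask is empty
    out = np.full((n_classes, 2), np.inf)
    for j in range(n_classes):
        organ_pred = (preds == j).astype(np.int64)
        organ_ref = (ref == j).astype(np.int64)
        if organ_pred.sum() and organ_ref.sum():
            out[j, 0] = mpm.hd95(organ_pred, organ_ref)
            out[j, 1] = mpm.assd(organ_pred, organ_ref)
    return out
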
Code example #8
File: evaluator_perregion.py Project: ihdia/Palmira
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            self.count += len(output["instances"])
            gt_segm = self.annotations[input["file_name"]]["segm_per_region"]
            try:
                _ = output["instances"].pred_masks
            except AttributeError:
                continue
            pred_segm = downsample_points(output)
            doc_ahd = {cat: [] for cat in categories_list}
            doc_hd = {cat: [] for cat in categories_list}
            doc_hd95 = {cat: [] for cat in categories_list}
            doc_iou = {cat: [] for cat in categories_list}
            doc_acc = {cat: [] for cat in categories_list}
            for reg_type in range(len(categories_list)):
                gt, pred = gt_segm[reg_type], pred_segm[reg_type]

                # Both have points
                if len(gt) and len(pred):
                    """"Peformed using medpy"""
                    gt_mask = np.zeros((len(gt), input["height"], input["width"]), dtype=np.int8)
                    for i in range(len(gt)):
                        cv2.fillPoly(gt_mask[i], np.array([gt[i]]).astype(np.int32), 1)
                    pred_mask = np.zeros((len(pred), input["height"], input["width"]), dtype=np.int8)
                    for i in range(len(pred)):
                        cv2.fillPoly(pred_mask[i], np.array([pred[i]]).astype(np.int32), 1)
                    gt_mask = gt_mask.astype(np.uint8)
                    gt_mask = (gt_mask * 255).astype(np.uint8)
                    pred_mask = pred_mask.astype(np.uint8)
                    pred_mask = (pred_mask * 255).astype(np.uint8)

                    def compute_iou_and_accuracy(arrs, edge_mask1):
                        intersection = cv2.bitwise_and(arrs, edge_mask1)
                        union = cv2.bitwise_or(arrs, edge_mask1)
                        intersection_sum = np.sum(intersection)
                        union_sum = np.sum(union)
                        iou = (intersection_sum) / (union_sum)
                        total = np.sum(arrs)
                        correct_predictions = intersection_sum
                        accuracy = correct_predictions / total
                        # print(iou, accuracy)
                        return iou, accuracy

                    iou_hd_dict = np.empty((len(gt), len(pred), 3))
                    # IOU, HD95, AHD
                    for i, each_gt_instance in enumerate(gt_mask):
                        for j, each_pred_instance in enumerate(pred_mask):
                            inst_iou, _ = compute_iou_and_accuracy(gt_mask[i], pred_mask[j])
                            inst_hd95 = hd95(gt_mask[i], pred_mask[j])
                            inst_ahd = assd(gt_mask[i], pred_mask[j])
                            iou_hd_dict[i][j][0] = inst_iou
                            iou_hd_dict[i][j][1] = inst_hd95
                            iou_hd_dict[i][j][2] = inst_ahd
                    corr_matrix = np.empty((len(gt), 4))
                    corr_matrix[:, 0] = np.argmax(iou_hd_dict[:, :, 0], 1)
                    for i, each_gt_instance_metric in enumerate(iou_hd_dict):
                        corr_matrix[i, 1:] = each_gt_instance_metric[corr_matrix[i, 0].astype(int)]  # np.int was removed from NumPy

                    gt_mask_copy = gt_mask.copy()
                    gt_mask_copy = np.repeat(gt_mask_copy[:, :, :, np.newaxis], 3, axis=3)

                    def bbox2(img):
                        rows = np.any(img, axis=1)
                        cols = np.any(img, axis=0)
                        rmin, rmax = np.where(rows)[0][[0, -1]]
                        cmin, cmax = np.where(cols)[0][[0, -1]]

                        return rmin, rmax, cmin, cmax

                    for i, each_gt_instance in enumerate(gt_mask):
                        rmin, rmax, cmin, cmax = bbox2(each_gt_instance)
                        pos = (int(cmin), int((rmin + rmax) / 2))
                        gt_mask_copy[i] = cv2.cvtColor(each_gt_instance, cv2.COLOR_GRAY2RGB)
                        cv2.putText(gt_mask_copy[i], f"{i}", pos, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 2)

                    gt_image = gt_mask_copy.sum(axis=0).clip(0, 255).astype(np.uint8)
                    # x = plt.imshow(gt_image)
                    # plt.show()
                    pred_image = pred_mask.sum(axis=0).clip(0, 255)
                    pred_image_copy = pred_image.copy()
                    pred_image_copy = np.repeat(pred_image_copy[:, :, np.newaxis], 3, axis=2).astype(np.uint8)

                    for i, each_pred_instance in enumerate(pred_mask):
                        rmin, rmax, cmin, cmax = bbox2(each_pred_instance)
                        pos = (int(cmin), int((rmin + rmax) / 2))
                        # pred_image_copy[i] = cv2.cvtColor(each_pred_instance, cv2.COLOR_GRAY2RGB)
                        if i in corr_matrix[:, 0].astype(int):
                            if self.write_to_csv:
                                metrics = (corr_matrix[np.where(i == corr_matrix[:, 0]), 1:]).squeeze()
                                if metrics.ndim > 1:
                                    metrics = metrics[np.argmax(metrics[:, 0])]
                                gt_idx = np.where((metrics == corr_matrix[:, 1:]).all(axis=1))[0].item()
                                self.metrics_for_csv.append(
                                    {
                                        "region_name": f"{osp.splitext('_'.join(input['file_name'].split('/')[-3:]))[0]}_gt_{reg_type}_{gt_idx}",
                                        "iou": metrics[0].round(2),
                                        "ahd": metrics[2].round(2),
                                        "hd95": metrics[1].round(2),
                                    }
                                )
                                continue
                            text_metrics = np.array2string(
                                (corr_matrix[np.where(i == corr_matrix[:, 0]), 1:].round(4)).squeeze()
                            )
                            cv2.putText(
                                pred_image_copy,
                                f"{i} - {text_metrics}",
                                pos,
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.7,
                                (255, 0, 0),
                                2,
                            )
                        else:
                            cv2.putText(pred_image_copy, f"{i}", pos, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

                    # x = plt.imshow(pred_image_copy)
                    # plt.show()
                    if not self.write_to_csv:
                        save_path = "final_outputs/comparision/bmrcnn_masks_whew/"
                        import os

                        file_name = os.path.splitext("_".join(input["file_name"].split("/")[-3:]))[0]
                        cv2.imwrite(
                            save_path + f"{file_name}_gt_{reg_type}.jpg", cv2.cvtColor(gt_image, cv2.COLOR_RGB2BGR)
                        )
                        cv2.imwrite(
                            save_path + f"{file_name}_pred_{reg_type}.jpg",
                            cv2.cvtColor(pred_image_copy, cv2.COLOR_RGB2BGR),
                        )
                    # res_iou, res_accuracy = compute_iou_and_accuracy(pred_mask, gt_mask)
                    # res_ahd, res_hd, res_hd95 = assd(pred_mask, gt_mask), hd(pred_mask, gt_mask), hd95(pred_mask, gt_mask)
                    # self.ahd[categories_list[reg_type]].append(res_ahd)
                    # self.hd[categories_list[reg_type]].append(res_hd)
                    # self.hd95[categories_list[reg_type]].append(res_hd95)
                    # self.hd[categories_list[reg_type]].append(hd(pred_mask, gt_mask))
                    # self.iou[categories_list[reg_type]].append(res_iou)
                    # self.acc[categories_list[reg_type]].append(res_accuracy)
                    #
                    # doc_ahd[categories_list[reg_type]].append(res_ahd)
                    # doc_hd[categories_list[reg_type]].append(res_hd)
                    # doc_hd95[categories_list[reg_type]].append(res_hd95)
                    # doc_hd[categories_list[reg_type]].append(hd(pred_mask, gt_mask))
                    # doc_iou[categories_list[reg_type]].append(res_iou)
                    # doc_acc[categories_list[reg_type]].append(res_accuracy)
                # One has points
                # elif len(gt) ^ len(pred):
                #     total_area = 0
                #     for each_pred in pred:
                #         total_area += PolyArea(each_pred[:, 0], each_pred[:, 1])
                #     hd = total_area / 100
                #     self.ahd[categories_list[reg_type]].append(hd)
                #     self.hd[categories_list[reg_type]].append(hd)
                #     self.hd95[categories_list[reg_type]].append(hd)
                # self.iou[categories_list[reg_type]].append(0)
                # self.acc[categories_list[reg_type]].append(0)
                # Both Empty
                # elif len(gt) == 0 and len(pred) != 0:

                else:
                    # self.hd[categories_list[reg_type]].append(0)
                    pass
            # print("Over for doc")
            total_ahd = list()
            for l in doc_ahd.values():
                total_ahd.extend(l)
            doc_ahd = np.mean(total_ahd)

            total_hd = list()
            for l in doc_hd.values():
                total_hd.extend(l)
            doc_hd = np.mean(total_hd)

            total_hd95 = list()
            for l in doc_hd95.values():
                total_hd95.extend(l)
            doc_hd95 = np.mean(total_hd95)

            total_iou = list()
            for l in doc_iou.values():
                total_iou.extend(l)
            doc_iou = np.mean(total_iou)

            total_acc = list()
            for l in doc_acc.values():
                total_acc.extend(l)
            doc_acc = np.mean(total_acc)

            self.doc_wise[input["file_name"]] = {
                "AHD": doc_ahd,
                "IOU": doc_iou,
                "HD": doc_hd,
                "HD95": doc_hd95,
                "ACC": doc_acc,
            }
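
The core of the matching above pairs each ground-truth instance with its best-IoU prediction before reading off HD95 and AHD; the same idea in plain numpy (a hypothetical helper, not the evaluator's code):

import numpy as np

def match_instances_by_iou(gt_masks, pred_masks):
    # returns, for each GT instance, the index of the highest-IoU prediction
    iou = np.zeros((len(gt_masks), len(pred_masks)))
    for i, g in enumerate(gt_masks):
        for j, p in enumerate(pred_masks):
            inter = np.logical_and(g, p).sum()
            union = np.logical_or(g, p).sum()
            iou[i, j] = inter / union if union else 0.0
    return iou.argmax(axis=1), iou
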
Code example #9
File: evaluator.py Project: ihdia/Palmira
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            self.count += len(output['instances'])
            gt_segm = self.annotations[input['file_name']]['segm_per_region']
            try:
                _ = output['instances'].pred_masks
            except AttributeError:
                continue
            pred_segm = downsample_points(output)
            doc_ahd = {cat: [] for cat in categories_list}
            doc_hd = {cat: [] for cat in categories_list}
            doc_hd95 = {cat: [] for cat in categories_list}
            doc_iou = {cat: [] for cat in categories_list}
            doc_acc = {cat: [] for cat in categories_list}
            for reg_type in range(len(categories_list)):
                gt, pred = gt_segm[reg_type], pred_segm[reg_type]

                # Both have points
                if len(gt) and len(pred):
                    gt_mask = np.zeros((input['height'], input['width']),
                                       dtype=np.int8)
                    for i in gt:
                        cv2.fillPoly(gt_mask,
                                     np.array([i]).astype(np.int32), 1)
                    pred_mask = np.zeros((input['height'], input['width']),
                                         dtype=np.int8)
                    for i in pred:
                        cv2.fillPoly(pred_mask,
                                     np.array([i]).astype(np.int32), 1)
                    gt_mask = gt_mask.astype(np.uint8)
                    gt_mask = (gt_mask * 255).astype(np.uint8)
                    pred_mask = pred_mask.astype(np.uint8)
                    pred_mask = (pred_mask * 255).astype(np.uint8)

                    def compute_iou_and_accuracy(arrs, edge_mask1):
                        intersection = cv2.bitwise_and(arrs, edge_mask1)
                        union = cv2.bitwise_or(arrs, edge_mask1)
                        intersection_sum = np.sum(intersection)
                        union_sum = np.sum(union)
                        iou = (intersection_sum) / (union_sum)
                        total = np.sum(arrs)
                        correct_predictions = intersection_sum
                        accuracy = correct_predictions / total
                        # print(iou, accuracy)
                        return iou, accuracy

                    res_iou, res_accuracy = compute_iou_and_accuracy(
                        pred_mask, gt_mask)
                    res_ahd, res_hd, res_hd95 = assd(pred_mask, gt_mask), hd(
                        pred_mask, gt_mask), hd95(pred_mask, gt_mask)
                    self.ahd[categories_list[reg_type]].append(res_ahd)
                    self.hd[categories_list[reg_type]].append(res_hd)
                    self.hd95[categories_list[reg_type]].append(res_hd95)
                    self.iou[categories_list[reg_type]].append(res_iou)
                    self.acc[categories_list[reg_type]].append(res_accuracy)

                    doc_ahd[categories_list[reg_type]].append(res_ahd)
                    doc_hd[categories_list[reg_type]].append(res_hd)
                    doc_hd95[categories_list[reg_type]].append(res_hd95)
                    doc_iou[categories_list[reg_type]].append(res_iou)
                    doc_acc[categories_list[reg_type]].append(res_accuracy)
                # One has points
                # elif len(gt) ^ len(pred):
                #     total_area = 0
                #     for each_pred in pred:
                #         total_area += PolyArea(each_pred[:, 0], each_pred[:, 1])
                #     hd = total_area / 100
                #     self.ahd[categories_list[reg_type]].append(hd)
                #     self.hd[categories_list[reg_type]].append(hd)
                #     self.hd95[categories_list[reg_type]].append(hd)
                # self.iou[categories_list[reg_type]].append(0)
                # self.acc[categories_list[reg_type]].append(0)
                # Both Empty
                # elif len(gt) == 0 and len(pred) != 0:

                else:
                    # self.hd[categories_list[reg_type]].append(0)
                    pass
            # print("Over for doc")
            total_ahd = list()
            for l in doc_ahd.values():
                total_ahd.extend(l)
            doc_ahd = np.mean(total_ahd)

            total_hd = list()
            for l in doc_hd.values():
                total_hd.extend(l)
            doc_hd = np.mean(total_hd)

            total_hd95 = list()
            for l in doc_hd95.values():
                total_hd95.extend(l)
            doc_hd95 = np.mean(total_hd95)

            total_iou = list()
            for l in doc_iou.values():
                total_iou.extend(l)
            doc_iou = np.mean(total_iou)

            total_acc = list()
            for l in doc_acc.values():
                total_acc.extend(l)
            doc_acc = np.mean(total_acc)

            self.doc_wise[input['file_name']] = {
                "AHD": doc_ahd,
                "IOU": doc_iou,
                "HD": doc_hd,
                "HD95": doc_hd95,
                "ACC": doc_acc
            }
Code example #10
def main(FLAGS):
    # set GPU device to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # define the experiments
    encoders = FLAGS.encoders.split(',')
    losses = FLAGS.losses.split(',')
    alpha = FLAGS.loss_param1.split(',')
    experiments = [encoders, losses, alpha]
    print(experiments)

    # get data files and sort them
    image_files = os.listdir(FLAGS.test_X_dir)
    anno_files = os.listdir(FLAGS.test_y_dir)
    image_files.sort()
    anno_files.sort()

    for experiment in itertools.product(*experiments):
        # switch to activate training session
        do_eval = True

        # load the base configurations
        if experiment[0] == 'UNet':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet2d.json'))
        elif experiment[0] == 'UNet3D':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet3d.json'))
        elif experiment[0] == 'VGG16' or experiment[0] == 'VGG19':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'vgg16_unet.json'))
        else:
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'xception_unet.json'))

        # set the path to the model (checkpoints are fine for this)
        ckpt_name = 'encoder_{}_loss_{}_alpha_{}_beta_{}_ckpt.h5'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        configs['paths']['load_model'] = os.path.join(FLAGS.ckpt_dir,
                                                      ckpt_name)

        # switch to inference
        configs['config_file']['type_signal'] = 'Inference'

        # perform some preprocessing
        configs['preprocessing']['categorical_switch'] = 'True'
        configs['preprocessing']['minimum_image_intensity'] = '0.0'
        configs['preprocessing']['maximum_image_intensity'] = '2048.0'
        configs['preprocessing']['normalization_type'] = '{}'.format(
            FLAGS.normalization)

        # set some other configurations
        configs['training_configurations']['batch_size'] = '{}'.format(
            FLAGS.batch_size)
        configs['config_file']['input_shape'] = '({}, {}, {})'.format(
            FLAGS.height, FLAGS.width, FLAGS.channels)

        # ensure xentropy/jaccard/focal only used once per encoder
        if experiment[1] == 'sparse_categorical_crossentropy'\
                or experiment[1] == 'categorical_crossentropy'\
                or experiment[1] == 'jaccard'\
                or experiment[1] == 'focal':
            if experiment[1] == 'focal':
                configs['loss_function']['parameter1'] = '0.75'
                configs['loss_function']['parameter2'] = '2.0'
            if experiment[1] == 'jaccard':
                configs['loss_function']['parameter1'] = '100.0'
            if experiment[2] == '0.3':
                configs['loss_function']['loss'] = experiment[1]
            else:
                do_eval = False
        elif experiment[1] == 'tversky':
            configs['loss_function']['loss'] = experiment[1]
            configs['loss_function']['parameter1'] = experiment[2]
            configs['loss_function']['parameter2'] = str(
                1. - literal_eval(experiment[2]))
        else:
            do_eval = False

        # create a location to store evaluation metrics
        metrics = np.zeros((len(image_files), FLAGS.classes, 8))
        overall_accuracy = np.zeros((len(image_files), ))

        # create a file writer to store the metrics
        excel_name = '{}_{}_{}_{}_metrics.xlsx'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        writer = pd.ExcelWriter(excel_name)

        for i in range(len(image_files)):
            K.clear_session()
            # define path to the test data
            configs['paths']['test_X'] = os.path.join(FLAGS.test_X_dir,
                                                      image_files[i])

            if do_eval is True:
                configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(
                    configs)

                if any(warnings_lvl1):
                    with open('errors.txt', 'a') as f:
                        for warning in warnings_lvl1:
                            f.write("%s\n" % warning)
                    # the with-block closes the file; no explicit close needed
                    print('Level 1 warnings encountered.')
                    print(
                        "The following level 1 warnings were identified and corrected based on engine defaults:"
                    )
                    for warning in warnings_lvl1:
                        print(warning)

                if any(errors_lvl1):
                    print('Level 1 errors encountered.')
                    print(
                        "Please fix the level 1 errors below before continuing:"
                    )
                    for error in errors_lvl1:
                        print(error)
                else:
                    configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(
                        configs_lvl1)

                    if any(warnings_lvl2):
                        print('Level 2 warnings encountered.')
                        print(
                            "The following level 2 warnings were identified and corrected based on engine defaults:"
                        )
                        for warning in warnings_lvl2:
                            print(warning)

                    if any(errors_lvl2):
                        print('Level 2 errors encountered.')
                        print(
                            "Please fix the level 2 errors below before continuing:"
                        )
                        for error in errors_lvl2:
                            print(error)
                    else:
                        engine = Dlae(configs)
                        engine.run()
                        if any(engine.errors):
                            print('Level 3 errors encountered.')
                            print(
                                "Please fix the level 3 errors below before continuing:"
                            )
                            for error in engine.errors:
                                print(error)

                # check if the images and annotations are the correct files
                print(image_files[i], anno_files[i])

                pred_file = glob(
                    os.path.join(FLAGS.predictions_temp_dir, '*.h5'))[0]
                pt_name = image_files[i].split('.')[0]
                new_name_raw = pt_name + '_{}_{}_{}_{}_raw.h5'.format(
                    experiment[0], experiment[1], experiment[2],
                    1. - literal_eval(experiment[2]))
                new_file_raw = os.path.join(FLAGS.predictions_final_dir,
                                            new_name_raw)
                os.rename(pred_file, new_file_raw)

                ref = read_hdf5_multientry(
                    os.path.join(FLAGS.test_y_dir, anno_files[i]))
                ref = np.squeeze(np.asarray(ref))

                preds = read_hdf5(new_file_raw)
                if experiment[0] == 'UNet3D':
                    ref = np.transpose(ref, (1, 2, 0))

                    # stitch the image back together first
                    sw = SlidingWindow(ref, [96, 96, 40], [128, 128, 48])
                    preds = sw.stitch_patches(preds, sw.window_corner_coords,
                                              [128, 128, 48], sw.img_shape,
                                              FLAGS.classes)
                    preds = np.argmax(preds, axis=-1)
                else:
                    preds = np.argmax(preds, axis=-1)

                overall_accuracy[i] = skm.accuracy_score(
                    ref.flatten(), preds.flatten())
                for j in range(FLAGS.classes):
                    organ_pred = (preds == j).astype(np.int64)
                    organ_ref = (ref == j).astype(np.int64)
                    if np.sum(organ_pred) == 0 or np.sum(organ_ref) == 0:
                        # overlap metrics default to their worst case and the
                        # surface distances to inf when either mask is empty
                        metrics[i, j, 0] = 0.
                        metrics[i, j, 1] = 0.
                        metrics[i, j, 2] = 1.
                        metrics[i, j, 3] = 0.
                        metrics[i, j, 4] = 0.
                        metrics[i, j, 5] = 0.
                        metrics[i, j, 6] = np.inf
                        metrics[i, j, 7] = np.inf
                    else:
                        metrics[i, j, 0] = jaccard_index(organ_ref, organ_pred)
                        metrics[i, j, 1] = dice_similarity_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 2] = relative_volume_difference(
                            organ_ref, organ_pred)
                        metrics[i, j, 3] = precision(organ_ref, organ_pred)
                        metrics[i, j, 4] = recall(organ_ref, organ_pred)
                        metrics[i, j, 5] = matthews_correlation_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 6] = mpm.hd95(organ_pred, organ_ref)
                        metrics[i, j, 7] = mpm.assd(organ_pred, organ_ref)
                print(overall_accuracy[i])
                print(metrics[i])

            else:
                pass

        if do_eval is True:
            for k in range(metrics.shape[-1]):
                data = pd.DataFrame(
                    metrics[:, :, k],
                    columns=['bg', 'pros', 'eus', 'sv', 'rect', 'blad'])
                data.to_excel(writer, sheet_name=str(k))
            acc = pd.DataFrame(overall_accuracy, columns=['acc'])
            acc.to_excel(writer, sheet_name='acc')
            writer.save()