Exemple #1
0
def CalculateWTTCET(wtpbregion, wtmaskregion, tcpbregion, tcmaskregion,
                    etpbregion, etmaskregion):
    """Append Dice, PPV, Hausdorff and sensitivity for the WT, TC and ET regions.

    Each ``*pbregion`` is a predicted region and each ``*maskregion`` the
    matching ground-truth region. Scores are appended in WT -> TC -> ET order
    to the lists ``wt_dices``/``wt_ppvs``/``wt_Hausdorf``/``wt_sensitivities``
    and their ``tc_``/``et_`` counterparts, which are resolved as module-level
    globals when this function runs. Returns None.

    NOTE(review): no module-level lists with these names are visible in this
    file (``main`` creates *local* lists with the same names), so calling this
    function may raise NameError - confirm where they are defined.
    """
    per_region = (
        (wtpbregion, wtmaskregion,
         wt_dices, wt_ppvs, wt_Hausdorf, wt_sensitivities),  # WT
        (tcpbregion, tcmaskregion,
         tc_dices, tc_ppvs, tc_Hausdorf, tc_sensitivities),  # TC
        (etpbregion, etmaskregion,
         et_dices, et_ppvs, et_Hausdorf, et_sensitivities),  # ET
    )
    for pred, truth, dice_out, ppv_out, hd_out, sens_out in per_region:
        dice_out.append(dice_coef(pred, truth))
        ppv_out.append(ppv(pred, truth))
        # hausdorff_distance is called mask-first here, unlike the other metrics.
        hd_out.append(hausdorff_distance(truth, pred))
        sens_out.append(sensitivity(pred, truth))
def test(model, test_inputs, test_labels):
    """Evaluate ``model`` on a held-out set and print summary metrics.

    :param model: tf.keras.Model inherited data type
        model being evaluated; must expose a ``batch_size`` attribute
    :param test_inputs: Numpy Array - shape (num_images, imsize, imsize, channels)
        input images to test on
    :param test_labels: Numpy Array - shape (num_images, 2)
        ground truth labels one-hot encoded
    :return: float, float, float, float
        dice score, sensitivity value (0.5 threshold), specificity value
        (0.5 threshold), and precision value
    :raises ValueError: if ``test_inputs`` contains no images
    """
    batch_size = model.batch_size
    num_images = test_inputs.shape[0]
    if num_images == 0:
        raise ValueError("test_inputs is empty; nothing to evaluate")

    # Slicing past the end of an array is safe, so the last (possibly
    # partial) batch is included and every image receives a prediction.
    batch_logits = [model(test_inputs[start:start + batch_size])
                    for start in range(0, num_images, batch_size)]
    # Concatenate once rather than re-copying the growing result each batch.
    # A single batch is passed through unchanged, as before.
    if len(batch_logits) == 1:
        all_logits = batch_logits[0]
    else:
        all_logits = np.concatenate(batch_logits, axis=0)

    sensitivity_val1 = sensitivity(test_labels, all_logits, threshold=0.15)
    sensitivity_val2 = sensitivity(test_labels, all_logits, threshold=0.3)
    sensitivity_val3 = sensitivity(test_labels, all_logits, threshold=0.5)
    specificity_val1 = specificity(test_labels, all_logits, threshold=0.15)
    specificity_val2 = specificity(test_labels, all_logits, threshold=0.3)
    specificity_val3 = specificity(test_labels, all_logits, threshold=0.5)

    dice = dice_coef(test_labels, all_logits)
    precision_val = precision(test_labels, all_logits)
    print(
        "Sensitivity 0.15: {}, Senstivity 0.3: {}, Senstivity 0.5: {}".format(
            sensitivity_val1, sensitivity_val2, sensitivity_val3))
    print("Specificity 0.15: {}, Specificity 0.3: {}, Specificity 0.5: {}".
          format(specificity_val1, specificity_val2, specificity_val3))
    print("DICE: {}, Precision: {}".format(dice, precision_val))

    # dice_coef presumably returns a TF tensor here (hence .numpy()) - confirm.
    return dice.numpy(), sensitivity_val3, specificity_val3, precision_val
def train(model, generator, verbose=False):
    """Train the model for one epoch.

    :param model: tf.keras.Model inherited data type
        model being trained; must expose ``loss_function``, ``optimizer``
        and ``trainable_variables``
    :param generator: BalancedDataGenerator
        a datagenerator which runs preprocessing and returns batches accessed
        by integer indexing (i.e. generator[0] returns the first batch of
        inputs and labels); ``generator.steps_per_epoch`` gives the batch count
    :param verbose: boolean
        whether to print training-batch metrics (every 4th step)
    :return: list
        list of losses from every batch of training
    """
    loss_list = []
    for step in range(generator.steps_per_epoch):
        images, labels = generator[step]
        with tf.GradientTape() as tape:
            logits = model(images)
            loss = model.loss_function(labels, logits)

        # Reporting is throttled to every 4th step to keep output readable.
        if step % 4 == 0 and verbose:
            sensitivity_val = sensitivity(labels, logits)
            specificity_val = specificity(labels, logits)
            precision_val = precision(labels, logits)
            train_dice = dice_coef(labels, logits)
            print("Scores on training batch after {} training steps".format(step))
            print("Sensitivity1: {}, Specificity: {}".format(
                sensitivity_val, specificity_val))
            print("Precision: {}, DICE: {}\n".format(precision_val,
                                                     train_dice))

        loss_list.append(loss)
        gradients = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(
            zip(gradients, model.trainable_variables))

    return loss_list
Exemple #4
0
    # NOTE(review): this fragment is indented as the body of a loop/function
    # whose header is not visible here; result_path, pr_image_path,
    # image_array_opr, TRANSFORMIX_PATH and the DSC/SNS/SPC/finalMI lists
    # come from that enclosing scope.
    transform_path0 = os.path.join(result_path, 'TransformParameters.0.txt')
    transform_path1 = os.path.join(result_path, 'TransformParameters.1.txt')
    final_transform_path = os.path.join(result_path, 'transform_pathfinal.txt')

    # Change FinalBSplineInterpolationOrder to 0 for binary mask transformation
    TransformParameterFileEditor(transform_path1, transform_path0, final_transform_path).modify_transform_parameter_file()

    # Make a new transformix object tr with the CORRECT PATH to transformix
    tr = elastix.TransformixInterface(parameters=final_transform_path,
                                      transformix_path=TRANSFORMIX_PATH)

    # Warp the patient segmentation and load the result as a numpy array.
    transformed_pr_path = tr.transform_image(pr_image_path, output_dir=result_path)
    image_array_tpr = sitk.GetArrayFromImage(sitk.ReadImage(transformed_pr_path))

    # Read the registration iteration log written into result_path.
    log_path = os.path.join(result_path, 'IterationInfo.1.R3.txt')
    log = elastix.logfile(log_path)

    # Record overlap scores of the warped segmentation against image_array_opr.
    DSC.append(dice_coef(image_array_opr, image_array_tpr))
    SNS.append(sensitivity(image_array_opr, image_array_tpr))
    SPC.append(specificity(image_array_opr, image_array_tpr))
    # NOTE(review): [-50:-1] averages 49 entries and excludes the final
    # iteration's metric; if "last 50 iterations" was intended, use [-50:].
    finalMI.append(statistics.mean(log['metric'][-50:-1]))

# One scatter panel per overlap score, each plotted against finalMI.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
for _panel, _values, _title in ((ax1, DSC, "DSC"),
                                (ax2, SNS, "SNS"),
                                (ax3, SPC, "SPC")):
    _panel.scatter(finalMI, _values)
    _panel.set_title(_title)
plt.show()
Exemple #5
0
def _npy_stem(path):
    """Return the basename of *path* with everything from ".npy" onward removed.

    A name that does not contain ".npy" is returned unchanged (a bare
    ``name[0:name.find(".npy")]`` would silently drop its last character).
    """
    name = os.path.basename(path)
    end = name.find(".npy")
    return name if end < 0 else name[:end]


def _prediction_to_rgb(pred):
    """Colour one sigmoid prediction ``pred[channel, H, W]`` as an RGB uint8 image.

    Channels are WT, TC, ET (in the order used by the original per-channel
    export). A pixel with channel 0 > 0.5 becomes green (0, 128, 0), then
    channel 1 > 0.5 overwrites it with red (255, 0, 0), then channel 2 > 0.5
    overwrites it with yellow (255, 255, 0). The image size follows ``pred``
    instead of a fixed 160x160.
    """
    rgb = np.zeros([pred.shape[-2], pred.shape[-1], 3], dtype=np.uint8)
    rgb[pred[0] > 0.5] = (0, 128, 0)
    rgb[pred[1] > 0.5] = (255, 0, 0)
    rgb[pred[2] > 0.5] = (255, 255, 0)
    return rgb


def _mask_to_rgb(npmask):
    """Colour a 2-D BraTS label map as an RGB uint8 image.

    Label 1 (NET, non-enhancing tumor) -> red, label 2 (ED, peritumoral
    edema) -> green (0, 128, 0), label 4 (ET, enhancing tumor) -> yellow.
    """
    gt = np.zeros([npmask.shape[0], npmask.shape[1], 3], dtype=np.uint8)
    gt[npmask == 1] = (255, 0, 0)
    gt[npmask == 2] = (0, 128, 0)
    gt[npmask == 4] = (255, 255, 0)
    return gt


def _rgb_to_regions(img):
    """Decode a colour label image into float32 (WT, TC, ET) binary masks.

    WT: any channel non-zero (red, green or yellow).
    TC: red channel == 255 (red NET or yellow ET).
    ET: green channel == 255 (yellow only). Green (0, 128, 0) is edema,
    which belongs to WT but not to ET.
    """
    wt = img[:, :, :].any(axis=2).astype(np.float32)
    tc = (img[:, :, 0] == 255).astype(np.float32)
    et = (img[:, :, 1] == 255).astype(np.float32)
    return wt, tc, et


def main():
    """Evaluate a trained BraTS segmentation model.

    ``val_args.mode == "GetPicture"`` writes colour-coded prediction PNGs to
    ``output/<name>/`` and ground-truth PNGs to ``output/<name>/GT/``.
    ``val_args.mode == "Calculate"`` reads those PNGs back, pairing them by
    file name, and prints mean Dice, PPV, sensitivity and Hausdorff distance
    for the WT, TC and ET regions.
    """
    val_args = parse_args()

    args = joblib.load('models/%s/args.pkl' % val_args.name)

    if not os.path.exists('output/%s' % args.name):
        os.makedirs('output/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    # create model
    print("=> creating model %s" % args.arch)
    model = mymodel.__dict__[args.arch](args)
    model = model.cuda()

    # Data loading code. glob order is arbitrary, so sort both lists to keep
    # image i paired with mask i inside the Dataset.
    val_img_paths = sorted(glob(
        r'D:\Project\CollegeDesign\dataset\Brats2018FoulModel2D\testImage\*'))
    val_mask_paths = sorted(glob(
        r'D:\Project\CollegeDesign\dataset\Brats2018FoulModel2D\testMask\*'))

    model.load_state_dict(torch.load('models/%s/model.pth' % args.name))
    model.eval()

    val_dataset = Dataset(args, val_img_paths, val_mask_paths)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)

    output_dir = 'output/%s/' % args.name

    if val_args.mode == "GetPicture":
        # Save the model's predicted label maps as colour PNGs.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            with torch.no_grad():
                for batch_idx, (inputs, _target) in tqdm(
                        enumerate(val_loader), total=len(val_loader)):
                    inputs = inputs.cuda()

                    # compute output
                    if args.deepsupervision:
                        output = model(inputs)[-1]
                    else:
                        output = model(inputs)
                    output = torch.sigmoid(output).data.cpu().numpy()
                    # shuffle=False, so batch k covers paths [k*bs, (k+1)*bs).
                    batch_paths = val_img_paths[args.batch_size * batch_idx:
                                                args.batch_size * (batch_idx + 1)]

                    for sample_idx in range(output.shape[0]):
                        rgb_name = _npy_stem(batch_paths[sample_idx]) + ".png"
                        imsave(output_dir + rgb_name,
                               _prediction_to_rgb(output[sample_idx]))

            torch.cuda.empty_cache()

        # Convert the ground-truth .npy label maps into colour PNGs.
        print("Saving GT,numpy to picture")
        val_gt_path = output_dir + "GT/"
        if not os.path.exists(val_gt_path):
            os.mkdir(val_gt_path)
        for mask_path in tqdm(val_mask_paths):
            rgb_name = _npy_stem(mask_path) + ".png"
            npmask = np.load(mask_path)
            imageio.imwrite(val_gt_path + rgb_name, _mask_to_rgb(npmask))
        print("Done!")

    if val_args.mode == "Calculate":
        # Compute Dice, sensitivity, PPV and Hausdorff distance per region.
        regions = ("WT", "TC", "ET")
        dices = {region: [] for region in regions}
        ppvs = {region: [] for region in regions}
        sensitivities = {region: [] for region in regions}
        hausdorffs = {region: [] for region in regions}

        # os.path.join keeps the pattern valid on every OS (a literal
        # "GT\*.png" only matches on Windows).
        gt_files = sorted(glob(os.path.join(output_dir, "GT", "*.png")))
        if len(gt_files) == 0:
            print("请先生成图片!")
            return

        for gt_file in tqdm(gt_files):
            # Pair each GT image with the prediction of the same file name
            # rather than relying on two independent glob orders.
            pb_file = os.path.join(output_dir, os.path.basename(gt_file))
            mask_regions = _rgb_to_regions(imread(gt_file))
            pb_regions = _rgb_to_regions(imread(pb_file))

            for region, maskregion, pbregion in zip(regions, mask_regions,
                                                    pb_regions):
                dices[region].append(dice_coef(pbregion, maskregion))
                ppvs[region].append(ppv(pbregion, maskregion))
                hausdorffs[region].append(
                    hausdorff_distance(maskregion, pbregion))
                sensitivities[region].append(sensitivity(pbregion, maskregion))

        for region in regions:
            print('%s Dice: %.4f' % (region, np.mean(dices[region])))
        print("=============")
        for region in regions:
            print('%s PPV: %.4f' % (region, np.mean(ppvs[region])))
        print("=============")
        for region in regions:
            print('%s sensitivity: %.4f' % (region, np.mean(sensitivities[region])))
        print("=============")
        for region in regions:
            print('%s Hausdorff: %.4f' % (region, np.mean(hausdorffs[region])))
        print("=============")
def evaluate_metrics(model, loader, vali_data, batchsize, dim_input_course,
                     dim_input_grade, dim_input_major):
    """Evaluate ``model`` over every batch of ``loader`` and print metrics.

    Two binary tasks are scored: ">=B or <B" and "credit/uncredit". Per-batch
    accuracies from ``accuracy`` are combined into count-weighted averages.
    The per-class counts from ``sensitivity`` are accumulated into tpr, fpr,
    fnr, tnr and F-values.

    :param model: recurrent model exposing ``hidden`` and ``init_hidden()``
    :param loader: iterable of ``(batch_x, batch_y)``; ``batch_x`` holds the
        indices of the batch samples in ``vali_data`` (``batch_y`` is unused)
    :param vali_data: validation data passed through to ``process_data``
    :param batchsize, dim_input_course, dim_input_grade, dim_input_major:
        forwarded unchanged to ``process_data``
    :return: None (results are printed)

    NOTE(review): the code reads ``accuracy(...)`` as
    (acc1, acc2, total, n1, n2) and ``sensitivity(...)`` as
    (tp, tn, true, false, predict_true, predict_false), each a length-2 array.
    This is inferred from how the values are used - confirm against those
    helpers.
    """
    model.eval()
    summ1 = 0  # >=B or <B
    summ2 = 0  # credit/uncredit

    len1 = len2 = 0
    tp = np.zeros(2)
    tn = np.zeros(2)
    true = np.zeros(2)
    false = np.zeros(2)
    predict_true = np.zeros(2)
    predict_false = np.zeros(2)

    # Pure inference: disabling autograd avoids building (and holding) a
    # computation graph for every batch.
    with torch.no_grad():
        for batch_x, _batch_y in loader:  # batch_x: index of batch data
            processed_data = process_data(batch_x.numpy(), vali_data,
                                          batchsize, dim_input_course,
                                          dim_input_grade, dim_input_major)
            padded_input = Variable(torch.Tensor(processed_data[0]),
                                    requires_grad=False).cuda()
            seq_len = processed_data[1]
            padded_label = Variable(torch.Tensor(processed_data[2]),
                                    requires_grad=False).cuda()

            # clear hidden states so no state leaks between batches
            model.hidden = model.init_hidden()
            model.hidden[0] = model.hidden[0].cuda()
            model.hidden[1] = model.hidden[1].cuda()
            # compute output
            y_pred = model(padded_input, seq_len)

            # only compute the accuracy for testing period; weight each batch
            # accuracy by its sample count for a correct overall average
            accura = accuracy(y_pred, seq_len, padded_label)
            len1 += accura[3]
            len2 += accura[4]
            summ1 += (accura[0] * accura[3])
            summ2 += (accura[1] * accura[4])

            print('>=B or not', accura[0], 'credit/uncredit', accura[1],
                  'total', accura[2])

            # compute tp, fp, fn, tn
            sen = sensitivity(y_pred, seq_len, padded_label)
            tp += sen[0]
            tn += sen[1]
            true += sen[2]
            false += sen[3]
            predict_true += sen[4]
            predict_false += sen[5]

    average_metric1 = summ1 / len1
    average_metric2 = summ2 / len2
    average_metric = (summ1 + summ2) / (len1 + len2)

    print("num of >=B or <B: ", len1, "num of credit/uncredit: ", len2)
    print("On average: ", average_metric1, average_metric2, average_metric)

    tpr = tp / true
    fpr = (predict_true - tp) / false
    fnr = (predict_false - tn) / true
    tnr = tn / false

    # F-values are computed for the negative class (tn / predict_false is its
    # precision, tnr its recall); index 0 is the ">=B" task, index -1 credit.
    precision_B = (tn / predict_false)[0]
    f_value_B = 2 / (1 / tnr[0] + 1 / precision_B)
    precision_uncredit = (tn / predict_false)[-1]
    f_value_uncredit = 2 / (1 / tnr[-1] + 1 / precision_uncredit)
    f_value = np.append(f_value_B, f_value_uncredit)
    print("tpr: ", tpr)
    print("fpr: ", fpr)
    print("fnr: ", fnr)
    print("tnr: ", tnr)
    print('F: ', f_value, 'average F:', np.average(f_value))
def get_scores(y_true, y_predict):
    """Score ``y_predict`` against ``y_true`` with six metrics.

    Returns a tuple, in this order: Dice coefficient, sensitivity,
    specificity, mean surface distance, mutual information and RMSE.
    """
    metric_fns = (dice_coef, sensitivity, specificity,
                  MeanSurfaceDistance, mutual_information, rmse)
    return tuple(fn(y_true, y_predict) for fn in metric_fns)
Exemple #8
0
# Disabled: Jacobian-determinant panels.
# fig, (ax6, ax7) = plt.subplots(1,2)
# ScrollView(J_binarized).plot(ax6, vmin=0, vmax=1)
# plt.title('Jacobian\ndeterminant')
#
# ScrollView(image_array_J).plot(ax7)
# ax7.set_title('Jacobian\ndeterminant')

# Three panels: reference segmentation, transformed patient segmentation,
# and their voxel-wise absolute difference.
fig, (ax8, ax9, ax10) = plt.subplots(1, 3)

ScrollView(image_array_opr).plot(ax8, vmin=0, vmax=1)
ax8.set_title("Unseen segmentation")

# Warp the patient segmentation with the transformix object `tr` (created
# earlier, not shown); output goes to the relative 'results' directory.
transformed_pr_path = tr.transform_image(pr_image_path, output_dir=r'results')
itk_image_tpr = sitk.ReadImage(transformed_pr_path)
image_array_tpr = sitk.GetArrayFromImage(itk_image_tpr)

ScrollView(image_array_tpr).plot(ax9, vmin=0, vmax=1)
ax9.set_title("Transformed patient segmentation")

# NOTE(review): if both arrays are unsigned integers, the subtraction wraps
# (e.g. 0 - 1 -> 255) before abs(). With vmin=0/vmax=1 the plot still shows
# such voxels as 1, but the raw values are not a true absolute difference.
segmentation_abs = abs(image_array_tpr - image_array_opr)
ScrollView(segmentation_abs).plot(ax10, vmin=0, vmax=1)
ax10.set_title("Absolute differences")

# Overlap scores of the warped segmentation against the reference.
DSC = dice_coef(image_array_opr, image_array_tpr)
SNS = sensitivity(image_array_opr, image_array_tpr)
SPC = specificity(image_array_opr, image_array_tpr)

print("Dice coefficient is %.2f, sensitivity is %.2f, specificity is %.2f" %
      (DSC, SNS, SPC))

plt.show()
Exemple #9
0
def dice_sens_loss(y_true, y_pred, alpha=0.5):
    """Loss that rewards overlap and recall: ``-dice - alpha * sensitivity``.

    :param y_true: ground-truth target
    :param y_pred: prediction
    :param alpha: weight of the sensitivity term (default 0.5)
    """
    dice_term = metrics.dice_coef(y_true, y_pred)
    sensitivity_term = metrics.sensitivity(y_true, y_pred)
    return -dice_term - alpha * sensitivity_term
Exemple #10
0
def main():
    """Benchmark several classifiers with Monte-Carlo cross-validation (MCCV).

    Loads ``Data/test_data.csv`` (``;``-separated, first column as index),
    drops rows with duplicate index keys and preprocesses the table. For each
    classifier it then runs ``num_folds`` stratified 85/15 train/test splits.
    Each split is standardized/imputed, ANOVA feature-selected, SMOTE-balanced
    and tuned with ``RandomizedSearchCV``. Fold-averaged accuracy,
    sensitivity, specificity and AUC are written to ``performances.csv`` and
    plotted to ``performance.png``.
    """
    # Set hyperparameters
    num_folds = 100
    label_name = "1"

    # Specify data location
    data_path = "Data/test_data.csv"

    # Load data to table
    df = pd.read_csv(data_path, sep=";", index_col=0)

    # Check if any labels are missing
    print("Number of missing values:\n", df.isnull().sum())
    print()

    # Only keep first instance if multiple instances have the same key
    num_instances_before = len(df)
    df = df[~df.index.duplicated(keep="first")]
    num_instances_diff = num_instances_before - len(df)
    if num_instances_diff > 0:
        print(
            "Warning: {} instances removed due to duplicate keys - only keeping first occurrence!"
            .format(num_instances_diff))

    # Perform standardized preprocessing
    preprocessor = TabularPreprocessor()
    df = preprocessor.fit_transform(df)

    # Display bar chart with number of samples per class
    # seaborn.countplot(x=label_name, data=df)
    # plt.title("Original class frequencies")
    # plt.savefig("Results/original_class_frequencies.png")
    # plt.close()

    # Separate data into training and test
    y = df[label_name]
    x = df.drop(label_name, axis="columns")

    # Get samples per class
    print("Samples per class")
    for label, count in zip(*np.unique(y, return_counts=True)):
        print("{}: {}".format(label, count))
    print()

    # Get number of classes
    num_classes = len(np.unique(df[label_name].values))

    # Setup classifiers
    knn = KNeighborsClassifier(weights="distance")
    # Candidate k values around round(sqrt(num_features)): five above it and
    # up to five at or below it (k must stay >= 1).
    k_center = np.round(np.sqrt(x.shape[1]))
    knn_param_grid = {
        "n_neighbors":
        [int(val) for val in k_center + np.arange(5) + 1] +
        [int(val) for val in k_center - np.arange(5) if val >= 1],
        "p":
        np.arange(1, 5),
    }

    dt = DecisionTreeClassifier()
    dt_param_grid = {
        "criterion": ["gini", "entropy"],
        "splitter": ["best", "random"],
        "max_depth": np.arange(1, 20),
        "min_samples_split": [2, 4, 6],
        "min_samples_leaf": [1, 3, 5, 6],
        # NOTE(review): "auto" was removed from DecisionTreeClassifier's
        # max_features in recent scikit-learn releases; candidates that draw
        # it fail to fit there (scored as NaN). Confirm the installed version.
        "max_features": ["auto", "sqrt", "log2"],
    }

    rf = RandomForestClassifier(n_estimators=100,
                                criterion="entropy",
                                max_depth=5,
                                min_samples_split=5,
                                min_samples_leaf=2)
    rf_param_grid = {}

    nn = MLPClassifier(hidden_layer_sizes=(32, 64, 32), activation="relu")
    nn_param_grid = {}

    clfs = {
        "knn": {"classifier": knn, "parameters": knn_param_grid},
        "dt": {"classifier": dt, "parameters": dt_param_grid},
        "rf": {"classifier": rf, "parameters": rf_param_grid},
        "nn": {"classifier": nn, "parameters": nn_param_grid},
    }

    clfs_performance = {"acc": [], "sns": [], "spc": [], "auc": []}

    # Initialize result table
    results = pd.DataFrame(index=list(clfs.keys()))

    # Iterate over classifiers
    for clf in clfs:

        # Initialize cumulated confusion matrix and fold-wise performance containers
        cms = np.zeros((num_classes, num_classes))
        performance_foldwise = {"acc": [], "sns": [], "spc": [], "auc": []}

        # Iterate over MCCV; the fold index doubles as the random seed
        for fold_index in range(num_folds):

            # Split into training and test data
            x_train, x_test, y_train, y_test = train_test_split(
                x, y, test_size=0.15, stratify=y, random_state=fold_index)

            # Perform standardization and feature imputation (fit on train only)
            intra_fold_preprocessor = TabularIntraFoldPreprocessor(
                k="automated", normalization="standardize")
            intra_fold_preprocessor = intra_fold_preprocessor.fit(x_train)
            x_train = intra_fold_preprocessor.transform(x_train)
            x_test = intra_fold_preprocessor.transform(x_test)

            # Perform (ANOVA) feature selection
            selected_indices, x_train, x_test = univariate_feature_selection(
                x_train.values,
                y_train.values,
                x_test.values,
                score_func=f_classif,
                num_features="log2n")

            # SMOTE oversampling of the training split only
            smote = SMOTE(random_state=fold_index, sampling_strategy=1)
            x_train, y_train = smote.fit_resample(x_train, y_train)

            # Setup model
            model = clfs[clf]["classifier"]
            model.random_state = fold_index

            # Hyperparameter tuning and keep model trained with the best set of hyperparameters
            optimized_model = RandomizedSearchCV(
                model,
                param_distributions=clfs[clf]["parameters"],
                cv=5,
                random_state=fold_index)
            optimized_model.fit(x_train, y_train)

            # Predict test data using trained model
            y_pred = optimized_model.predict(x_test)

            # Compute performance. The *_score names avoid shadowing `sns`,
            # the seaborn alias used by the commented-out heatmap code below.
            cm = confusion_matrix(y_test, y_pred)
            acc_score = accuracy_score(y_test, y_pred)
            sns_score = metrics.sensitivity(y_test, y_pred)
            spc_score = metrics.specificity(y_test, y_pred)
            auc_score = metrics.roc_auc(y_test, y_pred)

            # Append performance to fold-wise and overall containers
            cms += cm
            performance_foldwise["acc"].append(acc_score)
            performance_foldwise["sns"].append(sns_score)
            performance_foldwise["spc"].append(spc_score)
            performance_foldwise["auc"].append(auc_score)

        # Calculate overall performance (fold average, rounded to 2 decimals)
        for metric, values in performance_foldwise.items():
            clfs_performance[metric].append(np.round(np.mean(values), 2))

        # Display overall performances
        print("== {} ==".format(clf))
        print("Cumulative CM:\n", cms)
        for metric in clfs_performance:
            print("Avg {}: {}".format(metric, clfs_performance[metric][-1]))
        print()

        # Display confusion matrix
        # sns.heatmap(cms, annot=True, cmap="Blues", fmt="g")
        # plt.xlabel("Predicted")
        # plt.ylabel("Actual")
        # plt.title("{} - Confusion matrix".format(clf))
        # plt.savefig("Results/confusion_matrix-{}.png".format(clf))
        # plt.close()

    # Append performance to result table
    for metric in clfs_performance:
        results[metric] = clfs_performance[metric]

    # Save result table
    results.to_csv("performances.csv", sep=";")
    results.plot.bar(rot=45).legend(loc="upper right")
    plt.savefig("performance.png")
    plt.show()
    plt.close()
# Load every patient's prostate mask ("prostaat.mhd") and the MR image
# ("mr_bffe.mhd") of the patients whose name contains "p1".
# NOTE(review): masks are NOT filtered by "p1" while images are, so the two
# lists (and the held-out pair below) only correspond when every patient
# name contains "p1" - confirm this is intended.
masks = [
    sitk.GetArrayFromImage(
        sitk.ReadImage(os.path.join(data_path, patient, "prostaat.mhd")))
    for patient in patients
]
images = [
    sitk.GetArrayFromImage(
        sitk.ReadImage(os.path.join(data_path, patient, "mr_bffe.mhd")))
    for patient in patients
    if "p1" in patient
]

# Specify unknown image & mask: hold out the last entry of each list
unknown_mask = masks.pop()
unknown_image = images.pop()

# Mean of the remaining masks (only used to visualize the mean mask)
mask_mean = np.sum(masks, axis=0) / np.shape(masks)[0]

# Majority-voting combination of masks; d_* hold elapsed wall-clock seconds
st = time()
m1 = majority_voting(masks, 0.5)
d_m1 = time() - st

DSC_m1, SNS_m1, SPC_m1 = (dice_coef(unknown_mask, m1),
                          sensitivity(unknown_mask, m1),
                          specificity(unknown_mask, m1))

# Global weighted voting combination of masks
st = time()
w1 = global_weighted_voting(images, masks, unknown_image, 0.5)
d_w1 = time() - st

DSC_w1, SNS_w1, SPC_w1 = (dice_coef(unknown_mask, w1),
                          sensitivity(unknown_mask, w1),
                          specificity(unknown_mask, w1))

# Local weighted voting combination of masks
st = time()
g1 = local_weighted_voting(images, masks, unknown_image, 0.5, max_idx=50000)
d_g1 = time() - st

DSC_g1, SNS_g1, SPC_g1 = (dice_coef(unknown_mask, g1),
                          sensitivity(unknown_mask, g1),
                          specificity(unknown_mask, g1))