Esempio n. 1
0
def ComputeDiceWNet(result_path, t2_folder, data_type, th=0.3):
    """Load saved two-output predictions plus labels and collect four Dice lists.

    Args:
        result_path: Folder holding ``{data_type}_label1.npy``,
            ``{data_type}_preds1.npy`` and ``{data_type}_preds2.npy``.
        t2_folder: Only referenced by the commented-out lines in the loop;
            the live code never reads it.
        data_type: Prefix used for the ``.npy`` files and the ``_name.csv``
            file (presumably a split name such as 'train' -- confirm).
        th: Threshold applied to the raw prediction arrays before ``Dice``.

    Returns:
        Tuple ``(dice1_list, dice2_list, dice3_list, dice4_list)``.

    NOTE(review): the loop body reads ``label_arr1``, ``label_arr2``,
    ``label_arr3``, ``label_arr4``, ``preds_arr3`` and ``preds_arr4``, none of
    which are assigned in this function. Unless they exist as module-level
    globals elsewhere, the first iteration raises ``NameError``. The intended
    pairing of predictions and labels needs to be confirmed by the author
    before this can be fixed.
    """
    label_arr = np.load(
        os.path.join(result_path, '{}_label1.npy'.format(data_type)))
    preds_arr1 = np.load(
        os.path.join(result_path, '{}_preds1.npy'.format(data_type)))
    preds_arr2 = np.load(
        os.path.join(result_path, '{}_preds2.npy'.format(data_type)))

    dice1_list, dice2_list, dice3_list, dice4_list = [], [], [], []
    # NOTE(review): hard-coded absolute path; case_list is only used by the
    # commented-out T2 loading below.
    case_name = pd.read_csv(
        r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/OneSlice/{}_name.csv'.
        format(data_type))
    case_list = sorted(case_name.loc[0].tolist())
    for index in range(preds_arr1.shape[0]):
        # t2 = np.squeeze(np.load(os.path.join(t2_folder, case_list[index] + '.npy')))
        # t2_crop, _ = ExtractPatch(t2, (200, 200))
        # NOTE(review): pred1, pred2 and label are computed but never used in
        # the Dice calls that follow.
        pred1 = ROIOneHot(np.argmax(preds_arr1[index], axis=0))
        pred2 = ROIOneHot(np.argmax(preds_arr2[index], axis=0))
        label = label_arr[index]

        # Each Dice call thresholds a raw prediction array at `th`.
        # NOTE(review): label_arr1..4 and preds_arr3..4 are undefined here.
        dice1_list.append(
            Dice((preds_arr1[index] > th).astype(int), label_arr1[index]))
        dice2_list.append(
            Dice((preds_arr2[index] > th).astype(int), label_arr2[index]))
        dice3_list.append(
            Dice((preds_arr3[index] > th).astype(int), label_arr3[index]))
        dice4_list.append(
            Dice((preds_arr4[index] > th).astype(int), label_arr4[index]))

    return dice1_list, dice2_list, dice3_list, dice4_list
Esempio n. 2
0
    def case_test():
        """Segment every test case with the 2.5D U-Net and print mean Dice per class.

        For each case named in ``test_case_name.csv`` the T2 volume and ROI are
        built with ``seg.Nii2NPY``, the network is run through ``seg.run``,
        the prediction is turned into a one-hot mask, and one Dice value per
        channel (0-4) is recorded. The five per-class means are printed at
        the end.
        """
        seg = ProstateXSeg((200, 200))
        data_path = r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/OriginalData'
        model_folder = r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/Model'
        # The train/val lists ('train_case_name.csv', 'val_case_name.csv')
        # live in the same OneSlice folder; only the test list is used here.
        name_csv = os.path.join(
            r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/OneSlice',
            'test_case_name.csv')

        model = UNet25D(n_channels=1, n_classes=5, bilinear=True, factor=2)

        # One list of Dice scores per output channel.
        scores_by_class = [[] for _ in range(5)]

        case_list = pd.read_csv(name_csv).values.tolist()[0]
        for case in case_list:
            t2_arr, roi_arr = seg.Nii2NPY(
                case, os.path.join(data_path, case), r'', slice_num=3)

            # Checkpoint in use: UNet_0330_weightedloss / 13-2.679143.pt
            # (the author also noted 26-1.161448.pt for UNet_0330).
            checkpoint_dir = os.path.join(model_folder, 'UNet_0330_weightedloss')
            preds = seg.run(case,
                            model,
                            model_path=checkpoint_dir,
                            weights_path='13-2.679143.pt',
                            inputs=t2_arr,
                            outputs=roi_arr,
                            is_save=False)
            # When a tuple comes back, only its last element is evaluated.
            if isinstance(preds, tuple):
                preds = preds[-1]

            probabilities = torch.softmax(preds, dim=1)
            class_map = torch.argmax(probabilities, dim=1).cpu().data.numpy()
            preds = ROIOneHot(class_map)

            for channel, scores in enumerate(scores_by_class):
                scores.append(
                    Dice(preds[:, channel, ...], roi_arr[:, channel, ...]))

        means = [sum(scores) / len(scores) for scores in scores_by_class]
        print('{:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(*means))
Esempio n. 3
0
def ComputeDice():
    """Print the mean Dice of each class over all cases in ``case_folder``.

    Relies on the module-level ``case_folder``, ``label_folder`` and
    ``predict_folder``. For each file in ``case_folder`` the label and
    prediction sharing its name stem (text before ``'_0000'``) are loaded,
    one-hot encoded with ``ROIOneHot`` and compared channel by channel.

    The printed order is channels 1, 2, 3, 4 and finally channel 0.
    """
    channel_count = 5
    scores = [[] for _ in range(channel_count)]

    for case in os.listdir(case_folder):
        stem = case[:case.index('_0000')]
        case_label = stem + '.nii.gz'
        case_path = os.path.join(case_folder, case)
        label_path = os.path.join(label_folder, case_label)
        predict_path = os.path.join(predict_folder, case_label)

        # The image itself is loaded but not used by the metric.
        _, _t2, _ = LoadImage(case_path)
        _, label_volume, _ = LoadImage(label_path)
        label_onehot = ROIOneHot(label_volume)
        _, pred_volume, _ = LoadImage(predict_path)
        pred_onehot = ROIOneHot(pred_volume)

        for channel in range(channel_count):
            scores[channel].append(
                Dice(pred_onehot[:, channel, :, :],
                     label_onehot[:, channel, :, :]))

    # Channels 1-4 are reported first, channel 0 last.
    report_order = (1, 2, 3, 4, 0)
    means = [sum(scores[c]) / len(scores[c]) for c in report_order]
    print('{:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(*means))
Esempio n. 4
0
def Train():
    """Train the 5-class 2.5D U-Net with weighted cross-entropy plus Dice loss.

    Reads the module-level ``data_root`` (for ``train_name.csv`` and
    ``val_name.csv``) and ``model_root`` (checkpoints and TensorBoard logs go
    to ``model_root + '/UNet_0330_weigthedloss'``). Runs for at most
    ``total_epoch`` epochs; each epoch performs one training pass and one
    validation pass, logs losses / Dice to TensorBoard, prints per-class
    Dice, steps the ``ReduceLROnPlateau`` scheduler on the summed validation
    loss and consults ``EarlyStopping``.

    The TensorBoard writer is flushed every epoch and closed exactly once
    after the loop, so the early-stopping ``break`` no longer skips the
    flush and the writer is no longer closed while further epochs are still
    going to log.

    NOTE(review): the folder name is spelled 'weigthedloss' here while
    ``case_test`` loads from 'UNet_0330_weightedloss' -- confirm which
    spelling is intended before relying on either path.
    """
    torch.autograd.set_detect_anomaly(True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    input_shape = (200, 200)
    total_epoch = 10000
    batch_size = 24
    model_folder = MakeFolder(model_root + '/UNet_0330_weigthedloss')

    ClearGraphPath(model_folder)

    # Augmentation settings handed to _GetLoader.
    param_config = {
        RotateTransform.name: {'theta': ['uniform', -10, 10]},
        ShiftTransform.name: {'horizontal_shift': ['uniform', -0.05, 0.05],
                              'vertical_shift': ['uniform', -0.05, 0.05]},
        ZoomTransform.name: {'horizontal_zoom': ['uniform', 0.95, 1.05],
                             'vertical_zoom': ['uniform', 0.95, 1.05]},
        FlipTransform.name: {'horizontal_flip': ['choice', True, False]},
        BiasTransform.name: {'center': ['uniform', -1., 1., 2],
                             'drop_ratio': ['uniform', 0., 1.]},
        NoiseTransform.name: {'noise_sigma': ['uniform', 0., 0.03]},
        ContrastTransform.name: {'factor': ['uniform', 0.8, 1.2]},
        GammaTransform.name: {'gamma': ['uniform', 0.8, 1.2]},
        ElasticTransform.name: ['elastic', 1, 0.1, 256]
    }

    train_df = pd.read_csv(os.path.join(data_root, 'train_name.csv'))
    train_list = train_df.values.tolist()[0]
    val_df = pd.read_csv(os.path.join(data_root, 'val_name.csv'))
    val_list = val_df.values.tolist()[0]

    # NOTE(review): the validation loader receives the same param_config and
    # trailing True flag as the training loader -- confirm that is intended.
    train_loader, train_batches = _GetLoader(train_list, param_config, input_shape, batch_size, True)
    val_loader, val_batches = _GetLoader(val_list, param_config, input_shape, batch_size, True)

    model = UNet25D(n_channels=1, n_classes=5, bilinear=True, factor=2).to(device)
    model.apply(HeWeightInit)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Per-class cross-entropy weights; class 0 gets the smallest weight.
    weight = torch.from_numpy(np.array([0.1, 0.8, 0.8, 1., 1.])).float()
    criterion1 = torch.nn.CrossEntropyLoss(weight=weight.to(device))
    criterion2 = DiceLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5,
                                                           verbose=True)
    early_stopping = EarlyStopping(store_path=str(model_folder / '{}-{:.6f}.pt'), patience=50, verbose=True)
    writer = SummaryWriter(log_dir=str(model_folder / 'log'), comment='Net')

    for epoch in range(total_epoch):
        train_dice, val_dice = [], []
        train_dice_pz, val_dice_pz = [], []
        train_dice_cg, val_dice_cg = [], []
        train_dice_U, val_dice_U = [], []
        train_dice_AFMS, val_dice_AFMS = [], []

        train_loss, val_loss = 0., 0.
        train_loss1, val_loss1 = 0., 0.
        train_loss2, val_loss2 = 0., 0.

        model.train()
        for inputs, outputs in train_loader:
            # Class-index labels (argmax over the channel axis) for cross-entropy.
            outputs_nocoding = torch.argmax(outputs, dim=1)

            inputs = MoveTensorsToDevice(inputs, device)
            outputs_nocoding = MoveTensorsToDevice(outputs_nocoding, device)
            outputs = MoveTensorsToDevice(outputs.int(), device)

            preds = model(inputs)
            # Cross-entropy loss: preds are raw logits (no softmax applied),
            #                     labels are class indices (not one-hot).
            # Dice loss:          preds must be softmax probabilities,
            #                     labels must be encoded in the same layout as preds.
            softmax_preds = F.softmax(preds, dim=1)

            # Convert to numpy once per batch instead of once per metric.
            preds_np = softmax_preds.cpu().data.numpy()
            labels_np = outputs.cpu().data.numpy()
            train_dice.append(Dice(preds_np, labels_np))
            train_dice_pz.append(Dice(preds_np[:, 1], labels_np[:, 1]))
            train_dice_cg.append(Dice(preds_np[:, 2], labels_np[:, 2]))
            train_dice_U.append(Dice(preds_np[:, 3], labels_np[:, 3]))
            train_dice_AFMS.append(Dice(preds_np[:, 4], labels_np[:, 4]))

            loss1 = criterion1(preds, outputs_nocoding)
            loss2 = criterion2(softmax_preds, outputs)
            loss = loss1 + loss2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            for inputs, outputs in val_loader:
                outputs_nocoding = torch.argmax(outputs, dim=1)

                inputs = MoveTensorsToDevice(inputs, device)
                outputs_nocoding = MoveTensorsToDevice(outputs_nocoding, device)
                outputs = MoveTensorsToDevice(outputs.int(), device)

                preds = model(inputs)
                softmax_preds = F.softmax(preds, dim=1)

                preds_np = softmax_preds.cpu().data.numpy()
                labels_np = outputs.cpu().data.numpy()
                val_dice.append(Dice(preds_np, labels_np))
                val_dice_pz.append(Dice(preds_np[:, 1], labels_np[:, 1]))
                val_dice_cg.append(Dice(preds_np[:, 2], labels_np[:, 2]))
                val_dice_U.append(Dice(preds_np[:, 3], labels_np[:, 3]))
                val_dice_AFMS.append(Dice(preds_np[:, 4], labels_np[:, 4]))

                loss1 = criterion1(preds, outputs_nocoding)
                loss2 = criterion2(softmax_preds, outputs)
                loss = loss1 + loss2

                val_loss += loss.item()
                val_loss1 += loss1.item()
                val_loss2 += loss2.item()

        # Save Tensor Board: parameter histograms (names containing 'bn' are skipped).
        for name, param in model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(name + '_data', param.cpu().data.numpy(), epoch + 1)

        writer.add_scalars('Loss',
                           {'train_loss': train_loss / train_batches,
                            'val_loss': val_loss / val_batches}, epoch + 1)
        writer.add_scalars('Crossentropy Loss',
                           {'train_loss': train_loss1 / train_batches,
                            'val_loss': val_loss1 / val_batches}, epoch + 1)
        writer.add_scalars('Dice Loss',
                           {'train_loss': train_loss2 / train_batches,
                            'val_loss': val_loss2 / val_batches}, epoch + 1)

        # The series keys say '*_loss' but the values are mean Dice scores;
        # the keys are kept so existing TensorBoard runs stay comparable.
        writer.add_scalars('Dice',
                           {'train_loss': np.sum(train_dice) / len(train_dice),
                            'val_loss': np.sum(val_dice) / len(val_dice)}, epoch + 1)
        # Push this epoch's events to disk; the writer itself stays open.
        writer.flush()

        print('*************************************** Epoch {} | (◕ᴗ◕✿) ***************************************'.format(epoch + 1))
        print('    dice pz: {:.3f},     dice cg: {:.3f},     dice U: {:.3f},     dice AFMS: {:.3f}'.
              format(np.sum(train_dice_pz) / len(train_dice_pz), np.sum(train_dice_cg) / len(train_dice_cg),
                     np.sum(train_dice_U) / len(train_dice_U), np.sum(train_dice_AFMS) / len(train_dice_AFMS)))
        print('val-dice pz: {:.3f}, val-dice cg: {:.3f}, val-dice U: {:.3f}, val-dice AFMS: {:.3f}'.
              format(np.sum(val_dice_pz) / len(val_dice_pz), np.sum(val_dice_cg) / len(val_dice_cg),
                     np.sum(val_dice_U) / len(val_dice_U), np.sum(val_dice_AFMS) / len(val_dice_AFMS)))
        print()
        print('loss: {:.3f}, val-loss: {:.3f}'.format(train_loss / train_batches, val_loss / val_batches))

        # Both the scheduler and early stopping watch the summed validation loss.
        scheduler.step(val_loss)
        early_stopping(val_loss, model, (epoch + 1, val_loss))

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Close once, whether training ran to completion or stopped early.
    writer.close()
Esempio n. 5
0
def ComputeDice(result_path, t2_folder, data_type):
    """Compute per-class Dice for saved predictions and plot their histograms.

    Args:
        result_path: Folder with the saved ``.npy`` results. When it holds
            exactly 9 entries, ``{data_type}_preds2.npy`` is evaluated;
            otherwise ``{data_type}_preds.npy``. Labels always come from
            ``{data_type}_label.npy``.
        t2_folder: Not used by the live code.
        data_type: Prefix of the ``.npy`` file names.

    Returns:
        Tuple of five lists, one per channel 0-4, each holding one Dice value
        per sample.

    Side effects:
        Shows a 2x2 figure with a 20-bin histogram (mean in the title) for
        channels 1 to 4.
    """
    has_two_outputs = len(os.listdir(result_path)) == 9
    preds_template = '{}_preds2.npy' if has_two_outputs else '{}_preds.npy'
    preds_arr = np.load(
        os.path.join(result_path, preds_template.format(data_type)))
    label_arr = np.load(
        os.path.join(result_path, '{}_label.npy'.format(data_type)))

    dice_by_class = [[] for _ in range(5)]
    for index, raw_pred in enumerate(preds_arr):
        # Hard prediction: argmax over the channel axis, then one-hot again.
        pred = ROIOneHot(np.argmax(raw_pred, axis=0))
        label = label_arr[index]
        for channel, bucket in enumerate(dice_by_class):
            bucket.append(Dice(pred[channel], label[channel]))

    # Channel 0 is returned but not plotted.
    for panel, bucket in zip((221, 222, 223, 224), dice_by_class[1:]):
        plt.subplot(panel)
        plt.title('aver: {:.3f}'.format(sum(bucket) / len(bucket)))
        plt.hist(bucket, bins=20)
    plt.show()
    return tuple(dice_by_class)
Esempio n. 6
0
def ShoweResult(model_folder, data_type='train', num_pred=1, save_path=r''):
    """Plot labels against saved predictions, one figure per sample.

    Args:
        model_folder: Folder whose ``Result`` sub-folder holds the ``.npy`` files.
        data_type: Prefix of the ``.npy`` file names.
        num_pred: ``1`` reads ``{data_type}_preds.npy`` and draws a 2x5 grid
            (labels on top, predictions below). ``2`` reads
            ``{data_type}_preds1.npy`` and ``{data_type}_preds2.npy`` and
            draws a 3x5 grid (labels, first three channels of the first
            prediction, all five channels of the second). Any other value
            draws nothing.
        save_path: If non-empty, the folder is created when missing. With
            ``num_pred == 1`` figures are saved there as ``test_{index}.jpg``
            instead of being shown; with ``num_pred == 2`` figures are always
            shown.

    Titles on the prediction panels are the Dice between the argmax/one-hot
    prediction and the label for that channel.
    """
    if save_path and not os.path.exists(save_path):
        os.mkdir(save_path)

    n_channels = 5
    result_folder = os.path.join(model_folder, 'Result')
    label_path = os.path.join(result_folder, '{}_label.npy'.format(data_type))

    if num_pred == 1:
        pred_path = os.path.join(result_folder, '{}_preds.npy'.format(data_type))

        label = np.load(label_path)
        pred = np.load(pred_path)

        for index in range(label.shape[0]):
            pred_index = ROIOneHot(np.argmax(pred[index], axis=0))

            plt.figure(figsize=(12, 4))

            # Top row: ground-truth channels.
            for channel in range(n_channels):
                plt.subplot(2, 5, channel + 1)
                plt.axis('off')
                plt.imshow(label[index][channel, ...], cmap='gray')

            # Bottom row: raw prediction channels, titled with the Dice score.
            for channel in range(n_channels):
                plt.subplot(2, 5, n_channels + channel + 1)
                plt.title('{:.3f}'.format(
                    Dice(pred_index[channel], label[index][channel])))
                plt.axis('off')
                plt.imshow(pred[index][channel, ...], cmap='gray')

            if not save_path:
                plt.show()
            else:
                plt.savefig(os.path.join(save_path, 'test_{}.jpg'.format(index)))
                plt.close()

    elif num_pred == 2:
        pred_path1 = os.path.join(result_folder, '{}_preds1.npy'.format(data_type))
        pred_path2 = os.path.join(result_folder, '{}_preds2.npy'.format(data_type))

        label = np.load(label_path)
        pred1 = np.load(pred_path1)
        pred2 = np.load(pred_path2)

        for index in range(label.shape[0]):
            # Dice titles are based on the second prediction only.
            pred_index = ROIOneHot(np.argmax(pred2[index], axis=0))
            plt.figure(figsize=(12, 4))

            # Row 1: ground-truth channels.
            for channel in range(n_channels):
                plt.subplot(3, 5, channel + 1)
                plt.axis('off')
                plt.imshow(label[index][channel, ...], cmap='gray')

            # Row 2: only the first three channels of the first prediction.
            for channel in range(3):
                plt.subplot(3, 5, n_channels + channel + 1)
                plt.axis('off')
                plt.imshow(pred1[index][channel, ...], cmap='gray')

            # Row 3: all channels of the second prediction with Dice titles.
            for channel in range(n_channels):
                plt.subplot(3, 5, 2 * n_channels + channel + 1)
                plt.title('{:.3f}'.format(
                    Dice(pred_index[channel], label[index][channel])))
                plt.axis('off')
                plt.imshow(pred2[index][channel, ...], cmap='gray')

            plt.show()
Esempio n. 7
0
def ShowResult(save_path):
    """Plot label vs. prediction for every slice of every case in ``case_folder``.

    Relies on the module-level ``case_folder``, ``label_folder`` and
    ``predict_folder``. For each case the label and prediction volumes are
    one-hot encoded with ``ROIOneHot``; each index along the first axis gets
    a 2x5 figure with the five label channels on top and the five prediction
    channels below, each prediction panel titled with its Dice score.

    Args:
        save_path: If truthy, figures are written there as
            ``{case}_{index}.jpg`` and closed; otherwise they are shown.
    """
    n_channels = 5

    for case in os.listdir(case_folder):
        case_label = case[:case.index('_0000')] + '.nii.gz'
        case_path = os.path.join(case_folder, case)
        label_path = os.path.join(label_folder, case_label)
        predict_path = os.path.join(predict_folder, case_label)

        # The image is loaded but not drawn.
        _, _t2, _ = LoadImage(case_path)
        _, label_volume, _ = LoadImage(label_path)
        label = ROIOneHot(label_volume)
        _, pred_volume, _ = LoadImage(predict_path)
        prediction = ROIOneHot(pred_volume)

        # Per-channel views, channel order 0..4.
        label_channels = [label[:, c, :, :] for c in range(n_channels)]
        pred_channels = [prediction[:, c, :, :] for c in range(n_channels)]

        for index in range(label.shape[0]):
            plt.figure(figsize=(12, 4))

            # Top row: ground truth.
            for column, channel_label in enumerate(label_channels, start=1):
                plt.subplot(2, 5, column)
                plt.axis('off')
                plt.imshow(channel_label[index], cmap='gray')

            # Bottom row: prediction with Dice in the title.
            for column, (channel_pred, channel_label) in enumerate(
                    zip(pred_channels, label_channels), start=n_channels + 1):
                plt.subplot(2, 5, column)
                plt.title('{:.3f}'.format(
                    Dice(channel_pred[index], channel_label[index])))
                plt.axis('off')
                plt.imshow(channel_pred[index], cmap='gray')

            if not save_path:
                plt.show()
            else:
                plt.savefig(
                    os.path.join(save_path, '{}_{}.jpg'.format(case, index)))
                plt.close()