Code Example #1
File: dataset.py  Project: chenkarl/kits19
    def __init__(self,
                 case_num,
                 transform=None,
                 target_transform=None):
        if case_num == -1:  # -1 means: load every labeled case
            volumes, segmentations = [], []
            for i in range(210):
                tmp_volume, tmp_segmentation = load_case(i)
                volumes.append(tmp_volume.get_data())
                segmentations.append(tmp_segmentation.get_data())
            # stack all cases along the slice axis
            vol_data = np.concatenate(volumes, axis=0)
            kid_seg_ims = get_kid_img(np.concatenate(segmentations, axis=0))
        elif case_num < 210:
            volume, segmentation = load_case(case_num)
            vol_data = volume.get_data()
            kid_seg_ims = get_kid_img(segmentation.get_data())  # kidney masks
        else:  # cases >= 210 are unlabeled test volumes
            vol_data = load_volume(case_num).get_data()

        # normalize raw HU values into grayscale images
        vol_ims = normalize(vol_data, DEFAULT_HU_MAX, DEFAULT_HU_MIN)
        imgs = []
        if case_num < 210:  # labeled data (this also covers case_num == -1)
            for num in range(vol_ims.shape[0]):
                # if not np.all(kid_seg_ims[num] == 0):  # optionally skip empty masks
                imgs.append([vol_ims[num], kid_seg_ims[num]])
        else:
            for num in range(vol_ims.shape[0]):
                imgs.append([vol_ims[num], None])
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
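
For context, a minimal usage sketch. The class name (`KitsDataset`) and the usual `__len__`/`__getitem__` methods are not shown in the snippet above, so both are assumptions:

# Hypothetical usage; KitsDataset is an assumed name for the class that
# owns the __init__ above, with a standard Dataset interface assumed.
from torch.utils.data import DataLoader

dataset = KitsDataset(case_num=0)   # load a single labeled case
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for image, mask in loader:
    ...  # (grayscale slice, kidney mask) pairs for a training step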
Code Example #2
File: visualize.py  Project: Willzy-x/kits_experiment
def visualize(cid, seg, destination, ori,
              hu_min=DEFAULT_HU_MIN, hu_max=DEFAULT_HU_MAX,
              k_color=DEFAULT_KIDNEY_COLOR, t_color=DEFAULT_TUMOR_COLOR,
              alpha=DEFAULT_OVERLAY_ALPHA):
    # Prepare output location
    out_path = Path(destination)
    out_path.mkdir(parents=True, exist_ok=True)  # also create missing parents

    # Load segmentation and volume
    vol = load_volume(cid)
    vol = vol.get_data()
    if ori:
        seg = seg.get_data()
    else:
        seg = seg.astype(np.int32)
    
    # Convert to a visual format
    vol_ims = hu_to_grayscale(vol, hu_min, hu_max)
    seg_ims = class_to_color(seg, k_color, t_color)
    
    # Overlay the segmentation colors
    viz_ims = overlay(vol_ims, seg_ims, seg, alpha)

    # Save individual images to disk
    for i in range(viz_ims.shape[0]):
        fpath = out_path / ("{:05d}.png".format(i))
        scipy.misc.imsave(str(fpath), viz_ims[i])
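
Note that scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2, so the save loop above only runs on older SciPy builds. A drop-in sketch for newer environments, assuming imageio is available:

import imageio

for i in range(viz_ims.shape[0]):
    fpath = out_path / ("{:05d}.png".format(i))
    # imageio.imwrite replaces the removed scipy.misc.imsave here
    imageio.imwrite(str(fpath), viz_ims[i].astype(np.uint8))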
Code Example #3
def predict_volume_split(model, configs, model_path, case_id, suffix='nii'):
    # ---------- Load Trained Model ------------------------------------------------------
    model.eval()
    model = nn.parallel.DataParallel(model, device_ids=range(0, 1))
    # ---------- visualize pipeline ------------------------------------------------------
    case_num = str(case_id).zfill(5)
    case_pred = []
    tensor_list = []
    print('Loading case_' + case_num)
    # ---------- Load volume -------------------------------------------------------------
    if suffix == 'nii':
        img = load_volume(case_id)
        img_shape = img.shape
        print("image shape:", img_shape)
        # ---------- Preprocess --------------------------------------------------------------
        img_data = img.get_fdata()
    else:
        img_data = Kits2019DataLoader3D.load_patient(
            os.path.join('/home/data_share/npy_data/', case_num))[0][0]
        img_shape = img_data.shape
        # img_data = img_data[None, None, :, :, :]
    # img_data = clip_img(img=img_data, low_bound=-200, high_bound=400)
    # img_data = normalize(img=img_data, method='feature')
    img_list, _, _, _ = crop_multi_slices(img_data, new_slice=32)
    print("loading tensor...")
    for j in range(len(img_list)):
        img_arr = img_list[j][None, None, :, :, :]
        tensor_list.append(torch.tensor(img_arr, dtype=torch.float32))
    img_tensor_shape = tensor_list[0].size()
    # ---------- Predict segmentation mask -----------------------------------------------
    with torch.no_grad():
        for k in range(len(tensor_list)):
            prediction = model(tensor_list[k])

            if isinstance(prediction, list):
                prediction = prediction[-1]

            pred_shape = prediction.size()

            if pred_shape[2:] != img_tensor_shape[2:]:  # compare spatial dims only
                prediction = F.interpolate(
                    prediction, size=img_tensor_shape[2:5], mode='trilinear')
                pred_shape = img_tensor_shape

            prediction = prediction.permute(0, 2, 3, 4, 1).contiguous()
            prediction = prediction.view(pred_shape[2],
                                         pred_shape[3], pred_shape[4], -1)  # 2 labels

            prediction = F.log_softmax(prediction, dim=-1)  # classes sit on the last axis after the view
            prediction = prediction.cpu().numpy()
            print("before", prediction.shape)

            prediction = np.argmax(prediction, axis=-1)
            print("after", prediction.shape)

            case_pred.append(prediction)

    case_result = restore_slice(case_pred, ori_slice=img_shape[0], new_slice=32)
    return case_result
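
crop_multi_slices and restore_slice are project helpers whose definitions are not shown. Judging from the call sites, crop_multi_slices splits the depth axis into fixed 32-slice chunks and restore_slice stitches the per-chunk predictions back to the original depth. A minimal sketch of the split, under that assumption:

import numpy as np

def crop_multi_slices_sketch(volume, new_slice=32):
    # Assumed behavior: cut axis 0 into new_slice-deep chunks, zero-padding
    # the last chunk so every chunk has the same depth.
    chunks = []
    for start in range(0, volume.shape[0], new_slice):
        chunk = volume[start:start + new_slice]
        if chunk.shape[0] < new_slice:
            pad = new_slice - chunk.shape[0]
            chunk = np.pad(chunk, ((0, pad), (0, 0), (0, 0)))
        chunks.append(chunk)
    return chunks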
Code Example #4
def predict_volume(model, configs, model_path, case_id, suffix='nii'):
    # ---------- Load Trained Model ------------------------------------------------------
    model.eval()
    model = nn.parallel.DataParallel(model, device_ids=range(0, 1))
    # ---------- visualize pipeline ------------------------------------------------------
    case_num = str(case_id).zfill(5)
    case_pred = []
    tensor_list = []
    print('Loading case_' + case_num)
    # ---------- Load volume -------------------------------------------------------------
    if suffix == 'nii':
        img = load_volume(case_id)
        img_shape = img.shape
        # print("image shape:", img_shape)
        img_data = img.get_fdata()[None, None, :, :, :]
    else:
        img_data = Kits2019DataLoader3D.load_patient(os.path.join('/home/data_share/npy_data/', case_num))[0][0]
        img_data = img_data[None, None, :, :, :]
    # ---------- Preprocess --------------------------------------------------------------
    img_data = torch.tensor(img_data, dtype=torch.float32)
    img_shape = img_data.size()
    # ---------- Predict segmentation mask -----------------------------------------------
    with torch.no_grad():
        prediction = model(img_data)

        if isinstance(prediction, list):
            prediction = prediction[-1]

        pred_shape = prediction.size()

        if pred_shape[2:] != img_shape[2:]:  # compare spatial dims only
            prediction = F.interpolate(
                prediction, size=img_shape[2:5], mode='trilinear')
            pred_shape = img_shape

        prediction = prediction.permute(0, 2, 3, 4, 1).contiguous()
        prediction = prediction.view(pred_shape[2],
                                     pred_shape[3], pred_shape[4], -1)  # 2 labels

        if not configs['dice']:
            prediction = F.log_softmax(prediction, dim=-1)  # classes on the last axis
        else:
            prediction = F.softmax(prediction, dim=-1)
        prediction = prediction.cpu().numpy()
        print("before", prediction.shape)

        prediction = np.argmax(prediction, axis=-1)
        print("after", prediction.shape)

    return prediction
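
Because only an argmax follows, the softmax/log_softmax branch above cannot change the predicted labels; both are monotonic transforms of the logits. A quick self-contained check:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5)  # fake logits
# Monotonic transforms preserve argmax along the same axis
assert torch.equal(F.log_softmax(x, dim=-1).argmax(dim=-1),
                   F.softmax(x, dim=-1).argmax(dim=-1))
assert torch.equal(x.argmax(dim=-1), F.softmax(x, dim=-1).argmax(dim=-1))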
Code Example #5
def draw_contour_volume(cid, pred=None, path='./pics'):
    img_data = load_volume(cid).get_fdata()
    seg_data = load_segmentation(cid).get_fdata()

    seg_data = seg_data.astype(np.uint8)
    if pred is not None:  # pred defaults to None, so guard the cast
        pred = pred.astype(np.uint8)

    new_path = os.path.join(path, 'pics')
    os.makedirs(new_path, exist_ok=True)  # os.mkdir would fail on reruns

    for i in range(img_data.shape[0]):
        if pred is not None:
            result = draw_uni_contour(img_data[i], seg_data[i], pred[i])
        else:
            result = draw_uni_contour(img_data[i], seg_data[i])

        pic_name = str(i).zfill(4) + '.png'
        cv2.imwrite(os.path.join(new_path, pic_name), result)
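
A hypothetical way to chain this with the predictor from Code Example #3; the variable names are illustrative only:

case_id = 0
case_result = predict_volume_split(model, configs, model_path, case_id)
# Overlay ground-truth and predicted contours slice by slice
draw_contour_volume(case_id, pred=case_result, path='./output')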
Code Example #6
def predict_volume(model, configs, model_path, case_id):
    #---------- Load Trained Model ------------------------------------------------------
    model.eval()
    model = nn.parallel.DataParallel(model, device_ids=range(0, 2))
    #---------- visualize pipeline ------------------------------------------------------
    case_num = str(case_id).zfill(5)
    case_pred = []
    tensor_list = []
    print('Loading case_' + case_num)
    #---------- Load volume -------------------------------------------------------------
    img = load_volume(case_id)
    img_shape = img.shape
    print("image shape:", img_shape)
    #---------- Preprocess --------------------------------------------------------------
    img_data = img.get_fdata()[None, None, :, :, :]  # add batch and channel dims
    img_data = torch.tensor(img_data, dtype=torch.float32)  # the model expects a tensor, not an ndarray
    img_shape = img_data.size()
    #---------- Predict segmentation mask -----------------------------------------------
    with torch.no_grad():
        prediction = model(img_data)

        if isinstance(prediction, list):
            prediction = prediction[-1]

        pred_shape = prediction.size()

        if pred_shape[2:] != img_shape[2:]:  # compare spatial dims only
            prediction = F.interpolate(prediction,
                                       size=img_shape[2:5],
                                       mode='trilinear')
            pred_shape = img_shape

        prediction = prediction.permute(0, 2, 3, 4, 1).contiguous()
        prediction = prediction.view(pred_shape[2], pred_shape[3],
                                     pred_shape[4], -1)  # 2 labels

        prediction = F.log_softmax(prediction, dim=-1)  # classes are on the last axis after the view
        prediction = prediction.cpu().numpy()
        print("before", prediction.shape)

        prediction = np.argmax(prediction, axis=-1)
        print("after", prediction.shape)

    return prediction
Code Example #7
def predict_volume_slide_window(model, configs, model_path, case_id, patch_size=(160, 160, 128), strides=(40, 40, 20),
                                suffix='nii', do_mirror=True):
    # ---------- Load Trained Model ------------------------------------------------------
    model.eval()
    model = nn.parallel.DataParallel(model, device_ids=range(0, 1))
    # ---------- visualize pipeline ------------------------------------------------------
    case_num = str(case_id).zfill(5)
    print('Loading case_' + case_num)
    # ---------- Load volume -------------------------------------------------------------
    if suffix == 'nii':
        img = load_volume(case_id)
        img_shape = img.shape
        img_data = img.get_fdata()[None, None, :, :, :]
    else:
        img_data = Kits2019DataLoader3D.load_patient(os.path.join('/home/data_share/npy_data/', case_num))[0][0]
        img_data = img_data[None, None, :, :, :]
    c = np.zeros((img_data.shape))
    # ---------- Preprocess --------------------------------------------------------------
    img_data = torch.tensor(img_data, dtype=torch.float32)
    img_shape = img_data.size()
    full_pred = torch.zeros((img_shape))
    background_color = img_data.min()
    # ------------------------------------------------------------------------------------
    d, h, w = img_shape[2], img_shape[3], img_shape[4]
    with torch.no_grad():
        i = 0
        while i < d:
            j = 0
            while j < h:
                k = 0
                while k < w:
                    # ---------- Predict segmentation mask -----------------------------------------------
                    data_window = img_data[:, :, i:i + patch_size[0], j:j + patch_size[1], k:k + patch_size[2]]

                    if data_window.shape[2:] != patch_size:
                        # pad border windows up to the full patch size with the
                        # background value, then use the padded window
                        empty_data_window = torch.zeros(patch_size)
                        empty_data_window = empty_data_window[None, None, :, :, :] + background_color
                        dw, hw, ww = data_window.shape[2:]
                        empty_data_window[:, :, :dw, :hw, :ww] = data_window
                        data_window = empty_data_window

                    mirrored_data, mirrored_axes = mirror_when_test(data_window)

                    output = model(data_window)
                    output = F.interpolate(output, size=patch_size, mode='trilinear')

                    pred = output.permute(0, 2, 3, 4, 1).contiguous()
                    pred = pred.view(patch_size[0], patch_size[1], patch_size[2], -1)

                    if not configs['dice']:
                        pred = F.log_softmax(pred, dim=-1)  # classes on the last axis
                    else:
                        pred = F.softmax(pred, dim=-1)

                    pred = torch.argmax(pred, dim=-1)
                    pred = pred[None, None, :, :, :]

                    full_pred[:, :, i:i + patch_size[0], j:j + patch_size[1], k:k + patch_size[2]] += \
                        pred[:, :, :min(d - i, patch_size[0]), :min(h - j, patch_size[1]),
                        :min(w - k, patch_size[2])].float().cpu()

                    c[:, :, i:i + patch_size[0], j:j + patch_size[1], k:k + patch_size[2]] += 1

                    if do_mirror:
                        mirrored_data = mirrored_data.float()
                        output = model(mirrored_data)

                        output = F.interpolate(output, size=patch_size, mode='trilinear')

                        pred = output.permute(0, 2, 3, 4, 1).contiguous()
                        pred = pred.view(patch_size[0], patch_size[1], patch_size[2], -1)

                        if not configs['dice']:
                            pred = F.log_softmax(pred, dim=-1)  # classes on the last axis
                        else:
                            pred = F.softmax(pred, dim=-1)

                        pred = torch.argmax(pred, dim=-1)
                        pred = pred[None, None, :, :, :].cpu()
                        pred = reverse_mirror(pred, mirrored_axes)

                        full_pred[:, :, i:i + patch_size[0], j:j + patch_size[1], k:k + patch_size[2]] += \
                            pred[:, :, :min(d - i, patch_size[0]), :min(h - j, patch_size[1]),
                            :min(w - k, patch_size[2])].float()

                        c[:, :, i:i + patch_size[0], j:j + patch_size[1], k:k + patch_size[2]] += 1

                    if k + patch_size[2] >= w:
                        break
                    else:
                        k += strides[2]
                if j + patch_size[1] >= h:
                    break
                else:
                    j += strides[1]
            if i + patch_size[0] >= d:
                break
            else:
                i += strides[0]

        # average the accumulated labels over overlapping windows and round
        # to the nearest integer label (a crude per-voxel vote)
        full_pred = np.round(full_pred.numpy() / c)

    return full_pred.squeeze()
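
Each axis is scanned with a break-at-border loop, so the number of forward passes grows quickly with volume size. A small illustrative helper (not part of the project) to estimate the cost before running:

import math

def count_windows(dim, patch, stride):
    # Windows visited along one axis by the while-loops above:
    # positions 0, stride, 2*stride, ... until a window reaches the border.
    if dim <= patch:
        return 1
    return math.ceil((dim - patch) / stride) + 1

# e.g. a 300x512x512 case with the default patch_size/strides:
total = (count_windows(300, 160, 40) * count_windows(512, 160, 40)
         * count_windows(512, 128, 20))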
Code Example #8
def predict_volume_sw(model, configs, model_path, case_id, patch_size=(160, 160, 128), strides=(20, 20, 10),
                      suffix='nii', batch_size=4, mirror=False):
    # ---------- Load Trained Model ------------------------------------------------------
    model.eval()
    model = nn.parallel.DataParallel(model, device_ids=range(0, 1))
    # ---------- visualize pipeline ------------------------------------------------------
    case_num = str(case_id).zfill(5)
    print('Loading case_' + case_num)
    # ---------- Load volume -------------------------------------------------------------
    if suffix == 'nii':
        img = load_volume(case_id)
        img_shape = img.shape
        # print("image shape:", img_shape)
        img_data = img.get_fdata()
    else:
        img_data = Kits2019DataLoader3D.load_patient(os.path.join('/home/data_share/npy_data/', case_num))[0][0]
    # ---------- Preprocess --------------------------------------------------------------
    img_shape = img_data.shape
    if mirror:
        mirrored = np.flip(img_data.copy(), axis=(0, 1, 2))
        m_patches = image.extract_patches(mirrored, patch_size, strides)
        m_patches_shape = m_patches.shape

    patches = image.extract_patches(img_data, patch_size, strides)
    patches_shape = patches.shape

    patches = patches.reshape((-1, 1, patch_size[0], patch_size[1], patch_size[2]))
    patches = torch.tensor(patches, dtype=torch.float32)
    # ---------- Predict segmentation mask -----------------------------------------------
    with torch.no_grad():
        pred_list = []
        nb_batches = int(np.ceil(patches.shape[0] / float(batch_size)))  # batch size
        for batch_id in range(nb_batches):
            batch_index_1, batch_index_2 = batch_id * batch_size, (batch_id + 1) * batch_size
            data = patches[batch_index_1:batch_index_2]
            prediction = model(data)

            if isinstance(prediction, list):
                prediction = prediction[-1]

            pred_shape = prediction.size()

            if pred_shape[2:] != patch_size:
                prediction = F.interpolate(
                    prediction, size=patch_size, mode='trilinear')
                # torch.Size is immutable, so rebuild it rather than slice-assign
                pred_shape = pred_shape[:2] + torch.Size(patch_size)

            prediction = prediction.permute(0, 2, 3, 4, 1).contiguous()
            prediction = prediction.view(pred_shape[0], pred_shape[2],
                                         pred_shape[3], pred_shape[4], -1)  # 2 labels

            if not configs['dice']:
                prediction = F.log_softmax(prediction, dim=-1)  # classes on the last axis
            else:
                prediction = F.softmax(prediction, dim=-1)

            prediction = prediction.cpu().numpy()
            print("before", prediction.shape)

            prediction = np.argmax(prediction, axis=-1)
            print("after", prediction.shape)
            pred_list.append(prediction)

        final_result = reduce(lambda x, y: np.concatenate((x, y), axis=0), pred_list)
        final_result = final_result.reshape(patches_shape)

        if mirror:
            m_patches = m_patches.reshape((-1, 1, patch_size[0], patch_size[1], patch_size[2]))
            m_patches = torch.tensor(m_patches, dtype=torch.float32)
            pred_list = []
            nb_batches = int(np.ceil(m_patches.shape[0] / float(batch_size)))  # batch size
            for batch_id in range(nb_batches):
                batch_index_1, batch_index_2 = batch_id * batch_size, (batch_id + 1) * batch_size
                data = m_patches[batch_index_1:batch_index_2]
                prediction = model(data)

                if isinstance(prediction, list):
                    prediction = prediction[-1]

                pred_shape = prediction.size()

                if pred_shape[2:] != patch_size:
                    prediction = F.interpolate(
                        prediction, size=patch_size, mode='trilinear')
                    # torch.Size is immutable, so rebuild it rather than slice-assign
                    pred_shape = pred_shape[:2] + torch.Size(patch_size)

                prediction = prediction.permute(0, 2, 3, 4, 1).contiguous()
                prediction = prediction.view(pred_shape[0], pred_shape[2],
                                             pred_shape[3], pred_shape[4], -1)  # 2 labels

                if not configs['dice']:
                    prediction = F.log_softmax(prediction, dim=-1)  # classes on the last axis
                else:
                    prediction = F.softmax(prediction, dim=-1)

                prediction = prediction.cpu().numpy()
                print("before", prediction.shape)

                prediction = np.argmax(prediction, axis=-1)
                print("after", prediction.shape)
                pred_list.append(prediction)

            m_final_result = reduce(lambda x, y: np.concatenate((x, y), axis=0), pred_list)
            m_final_result = m_final_result.reshape(m_patches_shape)

            prediction = reconstruct_labels(final_result, img_shape, 3, strides, m_final_result)

        else:
            prediction = reconstruct_labels(final_result, img_shape, 3, strides)
    print(prediction.shape)

    return prediction
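
image.extract_patches here appears to be sklearn.feature_extraction.image.extract_patches, which was removed in scikit-learn 0.24. An equivalent sketch with NumPy stride tricks (assumes NumPy >= 1.20):

import numpy as np

def extract_patches_np(volume, patch_size, strides):
    # All sliding windows, then keep every strides-th one per axis;
    # output layout is (grid_d, grid_h, grid_w, *patch_size), matching
    # what the reshape above expects.
    windows = np.lib.stride_tricks.sliding_window_view(volume, patch_size)
    return windows[::strides[0], ::strides[1], ::strides[2]]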
Code Example #9
def visualizetest(cid,
                  source,
                  destination,
                  hu_min=DEFAULT_HU_MIN,
                  hu_max=DEFAULT_HU_MAX,
                  k_color=DEFAULT_KIDNEY_COLOR,
                  t_color=DEFAULT_TUMOR_COLOR,
                  alpha=DEFAULT_OVERLAY_ALPHA,
                  plane=DEFAULT_PLANE):

    plane = plane.lower()
    filename = cid
    plane_opts = ["axial", "coronal", "sagittal"]
    if plane not in plane_opts:
        raise ValueError(("Plane \"{}\" not understood. "
                          "Must be one of the following\n\n\t{}\n").format(
                              plane, plane_opts))

    # Prepare output location
    out_path = Path(destination)
    image_path = out_path / "image"  # Path joins portably; "\\" is Windows-only
    if not out_path.exists():
        out_path.mkdir()
    if not image_path.exists():
        image_path.mkdir()
    # Load segmentation and volume
    vol = load_volume(source, cid)
    spacing = vol.affine
    vol = vol.get_data()
    # seg = seg.get_data()
    # seg = seg.astype(np.int32)

    # Convert to a visual format
    vol_ims = hu_to_grayscale(vol, hu_min, hu_max)
    # seg_ims = class_to_color(seg, k_color, t_color)
    # Save individual images to disk
    if plane == plane_opts[0]:
        # Overlay the segmentation colors
        # alpha
        # viz_ims0 = overlay(vol_ims, seg_ims, seg, 0)
        # viz_ims1 = overlay(seg_ims, seg_ims, seg, 1)
        for i in range(vol_ims.shape[0]):
            # fpath = label_path / ("{}_{:05d}.png".format(filename,i))
            # scipy.misc.imsave(str(fpath), viz_ims1[i])
            fpath = image_path / ("{}_{:05d}.png".format(filename, i))
            scipy.misc.imsave(str(fpath), vol_ims[i])

    if plane == plane_opts[1]:
        # I use sum here to account for both legacy (incorrect) and
        # fixed affine matrices
        spc_ratio = np.abs(np.sum(spacing[2, :])) / np.abs(
            np.sum(spacing[0, :]))
        for i in range(vol_ims.shape[1]):
            fpath = out_path / ("{:05d}.png".format(i))
            vol_im = scipy.misc.imresize(
                vol_ims[:, i, :],
                (int(vol_ims.shape[0] * spc_ratio), int(vol_ims.shape[2])),
                interp="bicubic")
            # no segmentation is loaded in this test-time variant, so the
            # overlay step from the original visualize() is dropped
            scipy.misc.imsave(str(fpath), vol_im)

    if plane == plane_opts[2]:
        # I use sum here to account for both legacy (incorrect) and
        # fixed affine matrices
        spc_ratio = np.abs(np.sum(spacing[2, :])) / np.abs(
            np.sum(spacing[1, :]))
        for i in range(vol_ims.shape[2]):
            fpath = out_path / ("{:05d}.png".format(i))
            vol_im = scipy.misc.imresize(
                vol_ims[:, :, i],
                (int(vol_ims.shape[0] * spc_ratio), int(vol_ims.shape[1])),
                interp="bicubic")
            # no segmentation is loaded in this test-time variant, so the
            # overlay step from the original visualize() is dropped
            scipy.misc.imsave(str(fpath), vol_im)
Code Example #10
def preprocessing(case_nr, slice_number_to_print):

    from starter_code.utils import load_volume
    import numpy as np
    import nibabel
    import matplotlib.pyplot as plt
    from skimage import filters, exposure
    from skimage.morphology import disk, closing, opening, remove_small_holes
    from skimage.color import label2rgb, rgb2gray
    from joblib import Parallel, delayed
    import multiprocessing
    import os

    def normalize(image):
        minimum = np.min(image)
        maximum = np.max(image)
        result = (image - minimum) / (maximum - minimum)
        return result

    volume = load_volume(case_nr)
    #data_seg = segment.get_fdata()
    data = volume.get_fdata()
    data_preprocessed_vol = np.zeros(data.shape)
    data_masks = np.zeros(data.shape)

    slice_nr_list = list(range(data.shape[0]))  #List to iterate by

    #Definition of preprocessing per one slice
    def slice_pre(data_slice):

        data_normalized = normalize(data_slice)
        data_oryg = data_normalized.copy()
        data_hist = exposure.equalize_adapthist(data_normalized,
                                                clip_limit=0.1)

        data_median = filters.median(data_hist, selem=disk(3))
        data_median = exposure.adjust_gamma(data_median, gamma=10)

        thresholds = filters.threshold_multiotsu(data_median)
        regions = np.digitize(data_normalized, bins=thresholds)

        data_normalized[regions != 2] = 0

        data_pre = exposure.equalize_hist(data_normalized)

        thresholds = filters.threshold_multiotsu(data_pre)

        regions = np.digitize(data_pre, bins=thresholds)
        data_mask = label2rgb(regions)
        data_mask = rgb2gray(data_mask)
        data_mask[data_mask != np.max(data_mask)] = 0
        data_mask[data_mask == np.max(data_mask)] = 1
        data_mask = closing(data_mask, selem=disk(3)).astype(bool)
        data_mask = remove_small_holes(data_mask, area_threshold=300)  # expects a boolean mask

        data_eq = exposure.equalize_adapthist(data_oryg, clip_limit=0.15)
        data_preprocessed = data_eq * data_mask
        data_preprocessed = opening(data_preprocessed, selem=disk(3))

        return data_preprocessed, data_mask

    num_cores = multiprocessing.cpu_count()
    data_list = Parallel(n_jobs=num_cores)(
        delayed(slice_pre)(data[slice_nr, :, :]) for slice_nr in slice_nr_list)
    data_arr = np.array(data_list)
    data_preprocessed_vol = data_arr[:, 0, :, :]
    data_masks = data_arr[:, 1, :, :]

    path = os.getcwd()
    new_path = os.path.join(path, "preprocessed")
    os.makedirs(new_path, exist_ok=True)  # os.chdir below would fail otherwise
    os.chdir(new_path)

    img_pre = nibabel.Nifti1Image(data_preprocessed_vol, volume.affine)
    nibabel.save(img_pre, 'case{}_preprocessed.nii.gz'.format(case_nr))
    #Saving to file
    os.chdir(path)

    #This part prints slice if user specifies slice number
    if slice_number_to_print != -1:
        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3.5))
        fig.suptitle('Case: {}, Slice: {}'.format(case_nr,
                                                  slice_number_to_print),
                     fontsize=16)
        ax[0].imshow(data[slice_number_to_print, :, :], cmap='gray')
        ax[0].set_title('Original')
        ax[0].axis('off')

        ax[1].imshow(data_masks[slice_number_to_print, :, :], cmap='gray')
        ax[1].set_title('mask')
        ax[1].axis('off')

        ax[2].imshow(data_preprocessed_vol[slice_number_to_print, :, :],
                     cmap='gray')
        ax[2].set_title('preprocessed')
        ax[2].axis('off')

        plt.subplots_adjust()

        plt.show()
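
A hypothetical invocation, assuming the kits19 data layout expected by load_volume:

# Preprocess case 0 and preview slice 100; pass -1 to skip the plot.
preprocessing(case_nr=0, slice_number_to_print=100)
preprocessing(case_nr=1, slice_number_to_print=-1)  # save only, no preview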