def load_without_size_preprocessing(input_folder,
                                    cv_fold_num,
                                    train_test,
                                    idx):
    
    file_list = get_file_list(input_folder,
                              cv_fold_num)
    
    image_file = file_list[train_test][idx]
    
    # ============
    # read image and normalize it to be between 0 and 1
    # ============
    image_dat = utils.load_nii(image_file)
    image = image_dat[0].copy()
    image = utils.normalise_image(image, norm_type='div_by_max')
    
    # ============
    # read label and set RV label to 1, others to 0
    # ============
    label_file = image_file.split('_n4.nii.gz')[0] + '_gt.nii.gz'
    label_dat = utils.load_nii(label_file)
    label = label_dat[0].copy()
    label[label!=1] = 0
        
    return image, label
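A minimal usage sketch for the loader above; the folder path and fold number are placeholders, and np is assumed to be numpy as in the rest of the examples.

image, label = load_without_size_preprocessing(input_folder='/path/to/data/',   # placeholder path
                                               cv_fold_num=1,
                                               train_test='train',
                                               idx=0)
print(image.min(), image.max())   # intensities normalised to [0, 1]
print(np.unique(label))           # binary mask: RV label = 1, everything else = 0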
Example #2
def load_without_size_preprocessing(preproc_folder,
                                    patient_id):
                    
    # ==================
    # read bias corrected image and ground truth segmentation
    # ==================
    filepath_bias_corrected_nii_format = preproc_folder + 'Case' + patient_id + '_n4.nii.gz'
    filepath_seg_nii_format = preproc_folder + 'Case' + patient_id + '_segmentation.nii.gz'
    
    # ================================    
    # read bias corrected image
    # ================================    
    image = utils.load_nii(filepath_bias_corrected_nii_format)[0]

    # ================================    
    # normalize the image
    # ================================    
    image = utils.normalise_image(image, norm_type='div_by_max')

    # ================================    
    # read the labels
    # ================================    
    label = utils.load_nii(filepath_seg_nii_format)[0]            
    
    # ================================    
    # skimage.io with the SimpleITK plugin was used to read the images in the convert_to_nii_and_correct_bias_field function.
    # this led to the arrays being read as z-x-y.
    # move the axes appropriately, so that the resolution read above is correct for the corresponding axes.
    # ================================    
    image = np.swapaxes(np.swapaxes(image, 0, 1), 1, 2)
    label = np.swapaxes(np.swapaxes(label, 0, 1), 1, 2)
    
    return image, label
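The double swapaxes above reorders a z-x-y volume to x-y-z; the same reordering can be written as a single np.moveaxis call, as this small standalone check (with an illustrative shape) shows.

import numpy as np

vol_zxy = np.zeros((30, 256, 224))       # z-x-y, as read via the SimpleITK plugin
vol_xyz = np.moveaxis(vol_zxy, 0, -1)    # move axis 0 (z) to the end -> x-y-z
assert vol_xyz.shape == (256, 224, 30)
assert vol_xyz.shape == np.swapaxes(np.swapaxes(vol_zxy, 0, 1), 1, 2).shape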
Example #3
def load_without_size_preprocessing(input_folder, site_name, idx, depth):

    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + site_name + '/*/'))

    # ==================
    # get file paths
    # ==================
    patient_name, image_path, label_path = get_image_and_label_paths(
        filenames[idx])

    # ============
    # read the image and normalize it to be between 0 and 1
    # ============
    image, _, image_hdr = utils.load_nii(image_path)
    image = np.swapaxes(
        image, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets

    # ==================
    # read the label file
    # ==================
    label, _, _ = utils.load_nii(label_path)
    label = np.swapaxes(
        label, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    label = utils.group_segmentation_classes(
        label)  # group the segmentation classes as required

    # ============
    # create a segmentation mask and use it to get rid of the skull in the image
    # ============
    label_mask = np.copy(label)
    label_mask[label > 0] = 1
    image = image * label_mask

    # ==================
    # crop out portions of the image that are all zeros (rough registration via visual inspection)
    # ==================
    if site_name == 'CALTECH':
        image = image[:, 80:, :]
        label = label[:, 80:, :]
    elif site_name == 'STANFORD':
        image, label = center_image_and_label(image, label)

    # ==================
    # crop volume along z axis (as there are several zeros towards the ends)
    # ==================
    image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
    label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

    # ==================
    # normalize the image
    # ==================
    image = utils.normalise_image(image, norm_type='div_by_max')

    return image, label
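The helper utils.crop_or_pad_volume_to_size_along_z used above (and in later examples) is not shown here; below is a minimal sketch of what such a helper typically does -- an assumption for illustration, not the project's actual implementation.

import numpy as np

def crop_or_pad_volume_to_size_along_z_sketch(vol, target_z):
    # centre-crop or zero-pad symmetrically along the last axis to reach target_z slices
    z = vol.shape[2]
    if z >= target_z:
        start = (z - target_z) // 2
        return vol[:, :, start:start + target_z]
    pad_before = (target_z - z) // 2
    pad_after = target_z - z - pad_before
    return np.pad(vol, ((0, 0), (0, 0), (pad_before, pad_after)), mode='constant')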
def read_label(label_folder_path, shape, ed_es_diff, nifti_available=True):

    # ed_es_diff: number of slices between ED and ES

    nifti_lbl_path = label_folder_path[:label_folder_path.rfind('/') + 1]

    if nifti_available is False:

        label_ED = np.zeros(shape, dtype=np.uint8)
        label_ES = np.zeros(shape, dtype=np.uint8)

        for text_file in os.listdir(label_folder_path):

            if 'icontour' in text_file:

                text_file_path = os.path.join(label_folder_path, text_file)

                slice_id = int(text_file[4:8])

                slice_labels = open(text_file_path, "r")

                x = []
                y = []
                for l in slice_labels.readlines():
                    row = l.split()
                    x.append(int(float(row[1])))
                    y.append(int(float(row[0])))
                slice_labels.close()

                # fit a polygon to the points - this provides a way to make a binary segmentation mask
                xx, yy = polygon(x, y)

                # set the pixels inside the mask as 1
                if slice_id % 20 == 0:  # ED
                    label_ED[xx, yy, slice_id // 20] = 1

                elif (slice_id - ed_es_diff) % 20 == 0:  # ES
                    label_ES[xx, yy, (slice_id - ed_es_diff) // 20] = 1

        # ================================
        # save as nifti
        # ================================
        utils.save_nii(img_path=nifti_lbl_path + 'lbl_ED.nii.gz',
                       data=label_ED,
                       affine=np.eye(4))
        utils.save_nii(img_path=nifti_lbl_path + 'lbl_ES.nii.gz',
                       data=label_ES,
                       affine=np.eye(4))

    # ================================
    # read nifti labels
    # ================================
    label_ED = utils.load_nii(img_path=nifti_lbl_path + 'lbl_ED.nii.gz')[0]
    label_ES = utils.load_nii(img_path=nifti_lbl_path + 'lbl_ES.nii.gz')[0]

    return label_ED, label_ES
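A self-contained illustration of the contour-to-mask step used in read_label: skimage.draw.polygon returns the row/column indices of all pixels inside the closed contour, which are then set to 1 (the vertex coordinates below are made up).

import numpy as np
from skimage.draw import polygon

rows = [10, 10, 40, 40]                    # contour vertices (row coordinates)
cols = [10, 40, 40, 10]                    # contour vertices (column coordinates)
mask = np.zeros((64, 64), dtype=np.uint8)
rr, cc = polygon(rows, cols)               # indices of all pixels inside the polygon
mask[rr, cc] = 1                           # binary segmentation for this slice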
def read_image(image_folder_path, ed_es_diff, nifti_available=True):

    # ed_es_diff: number of slices between ED and ES
    px, py, pz, nx, ny, nz, list_of_dicom_filenames, pix_dtype = get_image_details(
        image_folder_path)
    nifti_img_path = image_folder_path[:image_folder_path.rfind('/') + 1]

    if nifti_available is False:

        imgDims = (nx, ny, nz)
        img_ED = np.zeros(imgDims, dtype=pix_dtype)
        img_ES = np.zeros(imgDims, dtype=pix_dtype)

        # ================================
        # read dicom series and create image volume
        # ================================
        for zz in range(len(list_of_dicom_filenames)):

            ds = dicom.read_file(list_of_dicom_filenames[zz])

            slice_id = int(list_of_dicom_filenames[zz][-8:-4])

            if slice_id % 20 == 0:  # ED
                img_ED[:, :, slice_id // 20] = ds.pixel_array

            elif (slice_id - ed_es_diff) % 20 == 0:  # ES
                img_ES[:, :, (slice_id - ed_es_diff) // 20] = ds.pixel_array

        # ================================
        # save as nifti, this sets the affine transformation as an identity matrix
        # ================================
        utils.save_nii(img_path=nifti_img_path + 'img_ED.nii.gz',
                       data=img_ED,
                       affine=np.eye(4))
        utils.save_nii(img_path=nifti_img_path + 'img_ES.nii.gz',
                       data=img_ES,
                       affine=np.eye(4))

        # ================================
        # do bias field correction
        # ================================
        for e in ['ED', 'ES']:
            input_img = nifti_img_path + 'img_' + e + '.nii.gz'
            output_img = nifti_img_path + 'img_' + e + '_n4.nii.gz'
            subprocess.call([
                "/usr/bmicnas01/data-biwi-01/bmicdatasets/Sharing/N4_th",
                input_img, output_img
            ])

    # ================================
    # read bias corrected image
    # ================================
    img_ED = utils.load_nii(img_path=nifti_img_path + 'img_ED_n4.nii.gz')[0]
    img_ES = utils.load_nii(img_path=nifti_img_path + 'img_ES_n4.nii.gz')[0]

    return img_ED, img_ES, px, py, pz
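dicom.read_file comes from the older pydicom API (imported as dicom); with current pydicom the same read would look roughly like this sketch, where the filename is a placeholder.

import pydicom

ds = pydicom.dcmread('/path/to/IM-0001-0020.dcm')   # placeholder filename
pixels = ds.pixel_array                             # same attribute as used above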
def compute_metrics_on_directories_raw(dir_gt, dir_pred):
    """
    Calculates the metrics from the metrics script (Dice scores, volumes and volume errors).
    Columns for the Hausdorff and average symmetric surface distances are allocated but
    currently left as zeros.

    :param dir_gt: Directory of the ground truth segmentation maps.
    :param dir_pred: Directory of the predicted segmentation maps.
    :return:
    """

    lst_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order)
    lst_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order)

    res = []
    cardiac_phase = []
    file_names = []

    measure_names = [
        'Dice LV', 'Volume LV', 'Err LV(ml)', 'Dice RV', 'Volume RV',
        'Err RV(ml)', 'Dice MYO', 'Volume MYO', 'Err MYO(ml)', 'Hausdorff LV',
        'Hausdorff RV', 'Hausdorff Myo', 'ASSD LV', 'ASSD RV', 'ASSD Myo'
    ]

    res_mat = np.zeros((len(lst_gt), len(measure_names)))

    ind = 0
    for p_gt, p_pred in zip(lst_gt, lst_pred):
        if os.path.basename(p_gt) != os.path.basename(p_pred):
            raise ValueError("The two files don't have the same name"
                             " {}, {}.".format(os.path.basename(p_gt),
                                               os.path.basename(p_pred)))

        gt, _, header = utils.load_nii(p_gt)
        pred, _, _ = utils.load_nii(p_pred)
        zooms = header.get_zooms()
        res.append(metrics(gt, pred, zooms))
        cardiac_phase.append(
            os.path.basename(p_gt).split('.nii.gz')[0].split('_')[-1])

        file_names.append(os.path.basename(p_pred))

        res_mat[ind, :9] = metrics(gt, pred, zooms)

        for ii, struc in enumerate([3, 1, 2]):

            gt_binary = (gt == struc) * 1
            pred_binary = (pred == struc) * 1

            res_mat[ind, 9 + ii] = 0
            res_mat[ind, 12 + ii] = 0

        ind += 1

    return res_mat, cardiac_phase, measure_names, file_names
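Note that the Hausdorff and ASSD columns above are allocated but left at zero. If those surface distances are needed, they could be filled in per structure roughly as sketched below, assuming the medpy package is available; this is not the project's own code.

from medpy.metric.binary import assd, hd

def surface_distances(gt_binary, pred_binary, zooms):
    # returns (Hausdorff distance, average symmetric surface distance) in mm
    if gt_binary.sum() == 0 or pred_binary.sum() == 0:
        return 0.0, 0.0   # distances are undefined if a structure is missing
    return (hd(pred_binary, gt_binary, voxelspacing=zooms),
            assd(pred_binary, gt_binary, voxelspacing=zooms))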
def load_without_size_preprocessing(input_folder, idx, labeller):

    # ===============================
    # read all the patient folders from the base input folder
    # ===============================
    folder_list = []
    for folder in os.listdir(input_folder):
        folder_path = os.path.join(input_folder, folder)
        if os.path.isdir(folder_path) and 't2_tse_tra.nii.gz' in os.listdir(
                folder_path):
            if 'segmentation_' + labeller + '.nii.gz' in os.listdir(
                    folder_path
            ) or 'segmentation_tra_' + labeller + '.nii.gz' in os.listdir(
                    folder_path):
                folder_list.append(folder_path)

    # ==================
    # read the image file
    # ==================
    image, _, _ = utils.load_nii(folder_list[idx] + '/t2_tse_tra_n4.nii.gz')
    # ============
    # normalize the image to be between 0 and 1
    # ============
    image = utils.normalise_image(image, norm_type='div_by_max')

    # ==================
    # read the label file
    # ==================
    if 'segmentation_' + labeller + '.nii.gz' in os.listdir(folder_list[idx]):
        label, _, _ = utils.load_nii(folder_list[idx] + '/segmentation_' +
                                     labeller + '.nii.gz')
    elif 'segmentation_tra_' + labeller + '.nii.gz' in os.listdir(
            folder_list[idx]):
        label, _, _ = utils.load_nii(folder_list[idx] + '/segmentation_tra_' +
                                     labeller + '.nii.gz')
    # ==================
    # remove extra label from some images
    # ==================
    label[label > 2] = 0

    return image, label


# ===============================================================
# End of file
# ===============================================================
Example #8
def count_slices(folder_list, idx_start, idx_end):

    num_slices = 0
    for idx in range(idx_start, idx_end):
        image, _, _ = utils.load_nii(folder_list[idx] + '/t2_tse_tra.nii.gz')
        num_slices = num_slices + image.shape[2]

    return num_slices
Example #9
def count_slices(folders_list, idx_start, idx_end, depth):

    num_slices = 0

    for idx in range(idx_start, idx_end):

        _, image_path, _ = get_image_and_label_paths(folders_list[idx])
        image, _, _ = utils.load_nii(image_path)
        num_slices = num_slices + depth  # the number of slices along the append axis will be fixed to this number

    return num_slices
def load_without_size_preprocessing(input_folder,
                                    cv_fold_num,
                                    train_test,
                                    idx):
    
    # ===============================
    # read all the patient folders from the base input folder
    # ===============================
    image_folder = os.path.join(input_folder, 'Prostate-3T')
    label_folder = os.path.join(input_folder, 'NCI_ISBI_Challenge-Prostate3T_Training_Segmentations')
    folder_list = get_patient_folders(image_folder,
                                      folder_base='Prostate3T-01',
                                      cv_fold_number = cv_fold_num)
    folder = folder_list[train_test][idx]

    # ==================
    # make a list of all dcm images for this subject
    # ==================                        
    lstFilesDCM = []  # create an empty list
    for dirName, subdirList, fileList in os.walk(folder):
        for filename in fileList:
            if ".dcm" in filename.lower():  # check whether the file's DICOM
                lstFilesDCM.append(os.path.join(dirName, filename))
                
    # ==================
    # read bias corrected image
    # ==================
    nifti_img_path = lstFilesDCM[0][:lstFilesDCM[0].rfind('/')+1]
    image = utils.load_nii(img_path = nifti_img_path + 'img_n4.nii.gz')[0]

    # ============
    # normalize the image to be between 0 and 1
    # ============
    image = utils.normalise_image(image, norm_type='div_by_max')

    # ==================
    # read the label file
    # ==================        
    label = utils.load_nii(img_path = nifti_img_path + 'lbl.nii.gz')[0]
    
    return image, label
Example #11
def load_without_size_preprocessing(input_folder, idx, protocol,
                                    preprocessing_folder, depth):

    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))

    # ==================
    # get file paths
    # ==================
    patient_name, image_path, label_path = get_image_and_label_paths(
        filenames[idx], protocol, preprocessing_folder)

    # ============
    # read the image and normalize it to be between 0 and 1
    # ============
    image, _, image_hdr = utils.load_nii(image_path)
    image = np.swapaxes(
        image, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    image = utils.normalise_image(image, norm_type='div_by_max')

    # ==================
    # read the label file
    # ==================
    label, _, _ = utils.load_nii(label_path)
    label = np.swapaxes(
        label, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    label = utils.group_segmentation_classes(
        label)  # group the segmentation classes as required

    # ==================
    # crop volume along z axis (as there are several zeros towards the ends)
    # ==================
    image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
    label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

    return image, label
Example #12
def count_slices(filenames, idx_start, idx_end, protocol, preprocessing_folder,
                 depth):

    num_slices = 0

    for idx in range(idx_start, idx_end):

        _, image_path, _ = get_image_and_label_paths(filenames[idx], protocol,
                                                     preprocessing_folder)

        image, _, _ = utils.load_nii(image_path)

        # num_slices = num_slices + image.shape[1] # will append slices along axes 1
        num_slices = num_slices + depth  # the number of slices along the append axis will be fixed to this number to crop out zeros

    return num_slices
Example #13
def count_slices_and_patient_ids_list(input_folder,
                                      cv_fold_number):

    num_slices = {'train': 0, 'test': 0, 'validation': 0}       
    patient_ids_list = {'train': [], 'test': [], 'validation': []}
    
    # we know that there are 50 subjects in this dataset: Case00 through Case49
    for dirName, subdirList, fileList in os.walk(input_folder):               
        
        for filename in fileList:
            
            if re.match(r'Case\d\d\.nii\.gz', filename):
                
                patient_id = filename[4:6]
                train_test = test_train_val_split(int(patient_id), cv_fold_number)
                filepath = input_folder + '/' + filename
                patient_ids_list[train_test].append(patient_id)
                img = utils.load_nii(filepath)[0]
                num_slices[train_test] += img.shape[0]               

    return num_slices, patient_ids_list
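A quick check of the filename pattern used above: re.match anchors at the start of the string, and escaping the dots keeps the extension match strict.

import re

assert re.match(r'Case\d\d\.nii\.gz', 'Case07.nii.gz') is not None
assert re.match(r'Case\d\d\.nii\.gz', 'Case07_segmentation.nii.gz') is None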
def prepare_data(input_image_folder, input_mask_folder, output_file, size,
                 target_resolution):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    hdf5_file = h5py.File(output_file, "w")

    expert_list = [
        'Readings_AH', 'Readings_EK', 'Readings_KC', 'Readings_KS',
        'Readings_OD', 'Readings_UM'
    ]
    num_annotators = len(expert_list)

    logging.info('Counting files and parsing meta data...')
    patient_id_list = {'test': [], 'train': [], 'validation': []}

    image_file_list = {'test': [], 'train': [], 'validation': []}
    mask_file_list = {'test': [], 'train': [], 'validation': []}

    num_slices = {'test': 0, 'train': 0, 'validation': 0}

    logging.info('Counting files and parsing meta data...')

    for folder in os.listdir(input_image_folder):

        folder_path = os.path.join(input_image_folder, folder)
        if os.path.isdir(folder_path) and folder.startswith('888'):

            patient_id = int(folder[len('888'):])  # strip the '888' prefix (lstrip would drop all leading 8s)

            if patient_id == 9:
                logging.info(
                    'WARNING: Skipping case 9, because one annotation has wrong dimensions...'
                )
                continue

            if patient_id % 5 == 0:
                train_test = 'test'
            elif patient_id % 4 == 0:
                train_test = 'validation'
            else:
                train_test = 'train'

            file_path = os.path.join(folder_path, 't2_tse_tra.nii.gz')

            annotator_mask_list = []
            for exp in expert_list:
                mask_folder = os.path.join(input_mask_folder, exp)
                file = glob.glob(
                    os.path.join(mask_folder,
                                 '*' + str(patient_id).zfill(4) + '_*.nii.gz'))
                # for ii in range(len(file)):
                #     if 'NCI' in file[ii]:
                #         del file[ii]
                assert len(
                    file
                ) == 1, 'more or less than one file matches the glob pattern %s' % (
                    '*' + str(patient_id).zfill(4) + '_*.nii.gz')
                annotator_mask_list.append(file[0])

            mask_file_list[train_test].append(annotator_mask_list)
            image_file_list[train_test].append(file_path)

            patient_id_list[train_test].append(patient_id)

            nifty_img = nib.load(file_path)
            num_slices[train_test] += nifty_img.shape[2]

    # Write the small datasets
    for tt in ['test', 'train', 'validation']:
        hdf5_file.create_dataset('patient_id_%s' % tt,
                                 data=np.asarray(patient_id_list[tt],
                                                 dtype=np.uint8))

    nx, ny = size
    n_test = num_slices['test']
    n_train = num_slices['train']
    n_val = num_slices['validation']

    print('Debug: Check if sets add up to correct value:')
    print(n_train, n_val, n_test, n_train + n_val + n_test)

    # Create datasets for images and masks
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'],
                              [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['masks_%s' % tt] = hdf5_file.create_dataset(
                "masks_%s" % tt, [num_points] + list(size) + [num_annotators],
                dtype=np.uint8)

    mask_list = {'test': [], 'train': [], 'validation': []}
    img_list = {'test': [], 'train': [], 'validation': []}

    logging.info('Parsing image files')

    for train_test in ['test', 'train', 'validation']:

        write_buffer = 0
        counter_from = 0

        patient_counter = 0
        for img_file, mask_files in zip(image_file_list[train_test],
                                        mask_file_list[train_test]):

            patient_counter += 1

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % img_file)

            img_dat = utils.load_nii(img_file)
            img = img_dat[0]

            masks = []
            for mf in mask_files:
                mask_dat = utils.load_nii(mf)
                masks.append(mask_dat[0])
            masks_arr = np.asarray(masks)  # annotator, size_x, size_y, size_z
            masks_arr = masks_arr.transpose(
                (1, 2, 3, 0))  # size_x, size_y, size_z, annotator

            img = utils.normalise_image(img)

            pixel_size = (img_dat[2].structarr['pixdim'][1],
                          img_dat[2].structarr['pixdim'][2],
                          img_dat[2].structarr['pixdim'][3])

            logging.info('Pixel size:')
            logging.info(pixel_size)

            scale_vector = [
                pixel_size[0] / target_resolution[0],
                pixel_size[1] / target_resolution[1]
            ]

            for zz in range(img.shape[2]):

                slice_img = np.squeeze(img[:, :, zz])
                slice_rescaled = transform.rescale(slice_img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                slice_mask = np.squeeze(masks_arr[:, :, zz, :])
                mask_rescaled = transform.rescale(slice_mask,
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  multichannel=True,
                                                  mode='constant')

                slice_cropped = crop_or_pad_slice_to_size(
                    slice_rescaled, nx, ny)
                mask_cropped = crop_or_pad_slice_to_size(mask_rescaled, nx, ny)

                # REMOVE SEMINAL VESICLES
                mask_cropped[mask_cropped == 3] = 0

                # DEBUG
                # import matplotlib.pyplot as plt
                # plt.figure()
                # plt.imshow(slice_img)
                #
                # plt.figure()
                # plt.imshow(slice_rescaled)
                #
                # plt.figure()
                # plt.imshow(slice_cropped)
                #
                # plt.show()
                # END DEBUG

                img_list[train_test].append(slice_cropped)
                mask_list[train_test].append(mask_cropped)

                write_buffer += 1

                # Writing needs to happen inside the loop over the slices
                if write_buffer >= MAX_WRITE_BUFFER:
                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, img_list, mask_list,
                                         counter_from, counter_to)
                    _release_tmp_memory(img_list, mask_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0

        # after file loop: Write the remaining data

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, img_list, mask_list,
                             counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, train_test)

    # After test train loop:
    hdf5_file.close()
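A sketch of reading the resulting file back; the dataset names follow those created above and the path is a placeholder.

import h5py

with h5py.File('/path/to/output_file.hdf5', 'r') as f:   # placeholder path
    train_images = f['images_train'][:]                   # [n_slices, nx, ny]
    train_masks = f['masks_train'][:]                     # [n_slices, nx, ny, num_annotators]
    train_pids = f['patient_id_train'][:]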
Example #15
def prepare_data(input_folder, preproc_folder, idx_start, idx_end):

    images = []
    affines = []
    patnames = []
    masks = []

    # read the filenames which have segmentations available
    filenames = sorted(glob.glob(input_folder + '*_seg.nii'))
    logging.info(
        'Number of images in the dataset that have ground truth annotations: %s'
        % str(len(filenames)))

    # iterate through all indices
    for idx in range(len(filenames)):

        # only consider images within the indices requested
        if (idx < idx_start) or (idx >= idx_end):
            #logging.info('skipping subject: %d' %idx)
            continue

        logging.info('==============================================')

        # get the name of the ground truth annotation for this subject
        filename_seg = filenames[idx]
        filename_img = filename_seg[:-8] + '.nii.gz'
        _patname = filename_seg[filename_seg[:-1].rfind('/') + 1:-8]

        if _patname == 'IXI014-HH-1236-T2':  # this subject has very poor resolution - 256x256x28
            continue

        # read the image
        logging.info('reading image: %s' % _patname)
        _img_data, _img_affine, _img_header = utils.load_nii(filename_img)
        # make all the images the same size by appending zero slices, to facilitate stacking them later
        # most images are of size 256*256*130
        if _img_data.shape[2] != 130:
            num_zero_slices = 130 - _img_data.shape[2]
            zero_slices = np.zeros(
                (_img_data.shape[0], _img_data.shape[1], num_zero_slices))
            _img_data = np.concatenate((_img_data, zero_slices), axis=-1)
        # normalise the image
        _img_data = image_utils.normalise_image(_img_data,
                                                norm_type='div_by_max')
        # save the pre-processed image
        utils.makefolder(preproc_folder + _patname)
        savepath = preproc_folder + _patname + '/preprocessed_image.nii'
        utils.save_nii(savepath, _img_data, _img_affine)
        # append to the list of all images, affines and patient names
        images.append(_img_data)
        affines.append(_img_affine)
        patnames.append(_patname)

        # read the segmentation mask (already grouped)
        _seg_data, _seg_affine, _seg_header = utils.load_nii(filename_seg)
        # make all the segmentations the same size by appending zero slices, to facilitate stacking them later
        # most images are of size 256*256*130
        if _seg_data.shape[2] != 130:
            num_zero_slices = 130 - _seg_data.shape[2]
            zero_slices = np.zeros(
                (_seg_data.shape[0], _seg_data.shape[1], num_zero_slices))
            _seg_data = np.concatenate((_seg_data, zero_slices), axis=-1)
        # save the pre-processed segmentation ground truth
        utils.makefolder(preproc_folder + _patname)
        savepath = preproc_folder + _patname + '/preprocessed_gt15.nii'
        utils.save_nii(savepath, _seg_data, _seg_affine)
        # append to the list of all masks
        masks.append(_seg_data)

    # convert the lists to arrays
    images = np.array(images)
    affines = np.array(affines)
    patnames = np.array(patnames)
    masks = np.array(masks, dtype='uint8')

    # merge along the y-axis to get a stack of x-z slices, for the images as well as the masks
    images = images.swapaxes(1, 2)
    images = images.reshape(-1, images.shape[2], images.shape[3])
    masks = masks.swapaxes(1, 2)
    masks = masks.reshape(-1, masks.shape[2], masks.shape[3])

    # save the processed images and masks so that they can be directly read the next time
    # make appropriate filenames according to the requested indices of training, validation and test images
    logging.info('Saving pre-processed files...')
    config_details = 'from%dto%d_' % (idx_start, idx_end)

    filepath_images = preproc_folder + config_details + 'images.npy'
    filepath_masks = preproc_folder + config_details + 'annotations15.npy'
    filepath_affine = preproc_folder + config_details + 'affines.npy'
    filepath_patnames = preproc_folder + config_details + 'patnames.npy'

    np.save(filepath_images, images)
    np.save(filepath_masks, masks)
    np.save(filepath_affine, affines)
    np.save(filepath_patnames, patnames)

    return images, masks, affines, patnames
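The arrays saved above can later be loaded (or memory-mapped) instead of re-running the preprocessing; a sketch with a placeholder folder and illustrative indices.

import numpy as np

preproc_folder = '/path/to/preproc/'                                      # placeholder
images = np.load(preproc_folder + 'from0to10_images.npy', mmap_mode='r')
masks = np.load(preproc_folder + 'from0to10_annotations15.npy', mmap_mode='r')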
Example #16
def prepare_data(input_folder, output_file, size, input_channels,
                 target_resolution):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    if len(size) != 3:
        raise AssertionError('Inadequate number of size parameters')
    if len(target_resolution) != 3:
        raise AssertionError(
            'Inadequate number of target resolution parameters')

    hdf5_file = h5py.File(output_file, "w")

    file_list = {'test': [], 'train': [], 'validation': []}

    logging.info('Counting files and parsing meta data...')

    pid = 0
    for folder in os.listdir(input_folder):
        print(folder)
        train_test = test_train_val_split(pid)
        pid = pid + 1
        file_list[train_test].append(folder)

    n_train = len(file_list['train'])
    n_test = len(file_list['test'])
    n_val = len(file_list['validation'])

    print('Debug: Check if sets add up to correct value:')
    print(n_train, n_val, n_test, n_train + n_val + n_test)

    # Create datasets for images and masks
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'],
                              [n_test, n_train, n_val]):

        if num_points > 0:
            print([num_points] + list(size) + [input_channels])
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt, [num_points] + list(size) + [input_channels],
                dtype=np.float32)
            data['masks_%s' % tt] = hdf5_file.create_dataset(
                "masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)
            data['pids_%s' % tt] = hdf5_file.create_dataset(
                "pids_%s" % tt, [num_points],
                dtype=h5py.special_dtype(vlen=str))

    mask_list = {'test': [], 'train': [], 'validation': []}
    img_list = {'test': [], 'train': [], 'validation': []}
    pids_list = {'test': [], 'train': [], 'validation': []}

    logging.info('Parsing image files')

    # get the max dimensions (original and after cropping) along each axis
    maxX = 0
    maxY = 0
    maxZ = 0
    maxXCropped = 0
    maxYCropped = 0
    maxZCropped = 0
    i = 0
    for train_test in ['test', 'train', 'validation']:
        for folder in file_list[train_test]:
            print("Doing file {}".format(i))
            i += 1

            baseFilePath = os.path.join(input_folder, folder, folder)
            img_c1, _, img_header = utils.load_nii(baseFilePath + "_t1.nii.gz")
            img_c2, _, _ = utils.load_nii(baseFilePath + "_t1ce.nii.gz")
            img_c3, _, _ = utils.load_nii(baseFilePath + "_t2.nii.gz")
            img_c4, _, _ = utils.load_nii(baseFilePath + "_flair.nii.gz")
            img_dat = np.stack((img_c1, img_c2, img_c3, img_c4), 3)

            maxX = max(maxX, img_dat.shape[0])
            maxY = max(maxY, img_dat.shape[1])
            maxZ = max(maxZ, img_dat.shape[2])
            img_dat_cropped = crop_volume_allDim(img_dat)
            maxXCropped = max(maxXCropped, img_dat_cropped.shape[0])
            maxYCropped = max(maxYCropped, img_dat_cropped.shape[1])
            maxZCropped = max(maxZCropped, img_dat_cropped.shape[2])
    print("Max x: {}, y: {}, z: {}".format(maxX, maxY, maxZ))
    print("Max cropped x: {}, y: {}, z: {}".format(maxXCropped, maxYCropped,
                                                   maxZCropped))

    for train_test in ['train', 'test', 'validation']:

        write_buffer = 0
        counter_from = 0

        for folder in file_list[train_test]:

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % folder)

            patient_id = folder

            baseFilePath = os.path.join(input_folder, folder, folder)
            img_c1, _, img_header = utils.load_nii(baseFilePath + "_t1.nii.gz")
            img_c2, _, _ = utils.load_nii(baseFilePath + "_t1ce.nii.gz")
            img_c3, _, _ = utils.load_nii(baseFilePath + "_t2.nii.gz")
            img_c4, _, _ = utils.load_nii(baseFilePath + "_flair.nii.gz")
            mask_dat, _, _ = utils.load_nii(baseFilePath + "_seg.nii.gz")

            img_dat = np.stack((img_c1, img_c2, img_c3, img_c4), 3)

            img, mask = crop_volume_allDim(img_dat.copy(), mask_dat.copy())

            pixel_size = (img_header.structarr['pixdim'][1],
                          img_header.structarr['pixdim'][2],
                          img_header.structarr['pixdim'][3])

            logging.info('Pixel size:')
            logging.info(pixel_size)

            ### PROCESSING LOOP FOR 3D DATA ################################

            scale_vector = [
                pixel_size[0] / target_resolution[0],
                pixel_size[1] / target_resolution[1],
                pixel_size[2] / target_resolution[2]
            ]

            if scale_vector != [1.0, 1.0, 1.0]:
                img = transform.rescale(img,
                                        scale_vector,
                                        order=1,
                                        preserve_range=True,
                                        multichannel=True,
                                        mode='constant')
                mask = transform.rescale(mask,
                                         scale_vector,
                                         order=0,
                                         preserve_range=True,
                                         multichannel=False,
                                         mode='constant')

            img = crop_or_pad_slice_to_size(img, size, input_channels)
            mask = crop_or_pad_slice_to_size(mask, size)

            img = normalise_image(img)

            img_list[train_test].append(img)
            mask_list[train_test].append(mask)
            pids_list[train_test].append(patient_id)

            write_buffer += 1

            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer
                _write_range_to_hdf5(data, train_test, img_list, mask_list,
                                     pids_list, counter_from, counter_to)
                _release_tmp_memory(img_list, mask_list, pids_list, train_test)

                # reset stuff for next iteration
                counter_from = counter_to
                write_buffer = 0

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        if len(file_list[train_test]) > 0:
            _write_range_to_hdf5(data, train_test, img_list, mask_list,
                                 pids_list, counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, pids_list, train_test)

    # After test train loop:
    hdf5_file.close()
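Newer scikit-image versions (0.19 and later) replace the multichannel flag of transform.rescale with channel_axis; on such versions the 4-channel rescale above would be written roughly as in this standalone sketch (shapes and scale factors are illustrative).

import numpy as np
from skimage import transform

vol = np.random.rand(40, 48, 36, 4)                 # x, y, z, channels (illustrative)
scale_vector = [1.0, 1.0, 2.0]                      # spatial scale factors only
vol_rescaled = transform.rescale(vol,
                                 scale_vector,
                                 order=1,
                                 preserve_range=True,
                                 channel_axis=-1,   # replaces multichannel=True
                                 mode='constant')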
Example #17
def prepare_data(input_folder, output_filepath, idx_start, idx_end, size,
                 target_resolution, labeller):

    # ===============================
    # create a hdf5 file
    # ===============================
    hdf5_file = h5py.File(output_filepath, "w")

    # ===============================
    # read all the patient folders from the base input folder
    # ===============================
    folder_list = []
    for folder in os.listdir(input_folder):
        folder_path = os.path.join(input_folder, folder)
        if os.path.isdir(folder_path) and 't2_tse_tra.nii.gz' in os.listdir(
                folder_path):
            if 'segmentation_' + labeller + '.nii.gz' in os.listdir(
                    folder_path
            ) or 'segmentation_tra_' + labeller + '.nii.gz' in os.listdir(
                    folder_path):
                folder_list.append(folder_path)

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_slices = count_slices(folder_list, idx_start, idx_end)
    data['images'] = hdf5_file.create_dataset("images",
                                              [num_slices] + list(size),
                                              dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels",
                                              [num_slices] + list(size),
                                              dtype=np.float32)

    # ===============================
    # initialize lists
    # ===============================
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []

    # ===============================
    # ===============================
    write_buffer = 0
    counter_from = 0
    patient_counter = 0

    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):

        patient_counter = patient_counter + 1

        # ==================
        # read the image file
        # ==================
        image, _, image_hdr = utils.load_nii(folder_list[idx] +
                                             '/t2_tse_tra_n4.nii.gz')

        # ============
        # normalize the image to be between 0 and 1
        # ============
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')

        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(float(image_hdr.get_zooms()[1]))
        pz_list.append(float(image_hdr.get_zooms()[2]))
        nx_list.append(image.shape[0])
        ny_list.append(image.shape[1])
        nz_list.append(image.shape[2])
        pat_names_list.append(folder_list[idx][folder_list[idx].rfind('/') +
                                               1:])

        # ==================
        # read the label file
        # ==================
        if 'segmentation_' + labeller + '.nii.gz' in os.listdir(
                folder_list[idx]):
            label, _, _ = utils.load_nii(folder_list[idx] + '/segmentation_' +
                                         labeller + '.nii.gz')
        elif 'segmentation_tra_' + labeller + '.nii.gz' in os.listdir(
                folder_list[idx]):
            label, _, _ = utils.load_nii(folder_list[idx] +
                                         '/segmentation_tra_' + labeller +
                                         '.nii.gz')

        # ==================
        # remove extra label from some images
        # ==================
        label[label > 2] = 0
        print(np.unique(label))

        # ======================================================
        ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
        # ======================================================
        scale_vector = [
            image_hdr.get_zooms()[0] / target_resolution[0],
            image_hdr.get_zooms()[1] / target_resolution[1]
        ]

        for zz in range(image.shape[2]):

            # ============
            # rescale the images and labels so that their orientation matches that of the nci dataset
            # ============
            image2d_rescaled = rescale(np.squeeze(image_normalized[:, :, zz]),
                                       scale_vector,
                                       order=1,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            label2d_rescaled = rescale(np.squeeze(label[:, :, zz]),
                                       scale_vector,
                                       order=0,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            # ============
            # rotate the images and labels so that their orientation matches that of the nci dataset
            # ============
            image2d_rescaled_rotated = np.rot90(image2d_rescaled, k=3)
            label2d_rescaled_rotated = np.rot90(label2d_rescaled, k=3)

            # ============
            # crop or pad to make of the same size
            # ============
            image2d_rescaled_rotated_cropped = crop_or_pad_slice_to_size(
                image2d_rescaled_rotated, size[0], size[1])
            label2d_rescaled_rotated_cropped = crop_or_pad_slice_to_size(
                label2d_rescaled_rotated, size[0], size[1])

            image_list.append(image2d_rescaled_rotated_cropped)
            label_list.append(label2d_rescaled_rotated_cropped)

            write_buffer += 1

            # Writing needs to happen inside the loop over the slices
            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer

                _write_range_to_hdf5(data, image_list, label_list,
                                     counter_from, counter_to)

                _release_tmp_memory(image_list, label_list)

                # update counters
                counter_from = counter_to
                write_buffer = 0

    logging.info('Writing remaining data')
    counter_to = counter_from + write_buffer
    _write_range_to_hdf5(data, image_list, label_list, counter_from,
                         counter_to)
    _release_tmp_memory(image_list, label_list)

    # Write the small datasets
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames',
                             data=np.asarray(pat_names_list, dtype="S10"))

    # After test train loop:
    hdf5_file.close()
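One caveat in the block above (and in the similar blocks below): the "S10" dtype truncates patient names to 10 bytes. h5py's variable-length string dtype, already used for the pids dataset in an earlier example, avoids this; a small standalone sketch:

import h5py

pat_names_list = ['a_subject_with_a_long_name_0001']        # illustrative name
str_dt = h5py.special_dtype(vlen=str)
with h5py.File('patnames_demo.hdf5', 'w') as f:             # placeholder file
    f.create_dataset('patnames', data=pat_names_list, dtype=str_dt)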
Example #18
def prepare_data(input_folder, output_file, idx_start, idx_end, protocol, size,
                 depth, target_resolution, preprocessing_folder):

    # ========================
    # read the filenames
    # ========================
    folders_list = sorted(glob.glob(input_folder + '/*/'))
    logging.info('Number of images in the dataset: %s' %
                 str(len(folders_list)))

    # =======================
    # create a hdf5 file
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_slices = count_slices(folders_list, idx_start, idx_end, depth)

    data['images'] = hdf5_file.create_dataset("images",
                                              [num_slices] + list(size),
                                              dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels",
                                              [num_slices] + list(size),
                                              dtype=np.uint8)

    # ===============================
    # initialize lists
    # ===============================
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []

    # ===============================
    # ===============================
    write_buffer = 0
    counter_from = 0

    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):

        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(
            folders_list[idx])

        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)
        image = np.swapaxes(
            image, 1, 2
        )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets

        # ==================
        # read the label file
        # ==================
        label, _, _ = utils.load_nii(label_path)
        label = np.swapaxes(
            label, 1, 2
        )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        # labels have already been grouped as required

        # ============
        # create a segmentation mask and use it to get rid of the skull in the image
        # ============
        label_mask = np.copy(label)
        label_mask[label > 0] = 1
        image = image * label_mask

        # ==================
        # crop out portions of the image that are all zeros (rough registration via visual inspection)
        # ==================
        image, label = center_image_and_label(image, label)

        # plt.figure(); plt.imshow(image[:,:,50], cmap='gray'); plt.title(patient_name); plt.show(); plt.close()

        # ==================
        # crop volume along z axis (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
        label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(
            float(image_hdr.get_zooms()[2])
        )  # since axes 1 and 2 have been swapped. this is important when dealing with pixel dimensions
        pz_list.append(float(image_hdr.get_zooms()[1]))
        nx_list.append(image.shape[0])
        ny_list.append(
            image.shape[1]
        )  # since axes 1 and 2 have been swapped. however, only the final axis locations are relevant when dealing with shapes
        nz_list.append(image.shape[2])
        pat_names_list.append(patient_name)

        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')

        # ======================================================
        ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
        # ======================================================
        scale_vector = [
            image_hdr.get_zooms()[0] / target_resolution[0],
            image_hdr.get_zooms()[2] / target_resolution[1]
        ]  # since axes 1 and 2 have been swapped. this is important when dealing with pixel dimensions

        for zz in range(image.shape[2]):

            # ============
            # rescale the images and labels so that their orientation matches that of the nci dataset
            # ============
            image2d_rescaled = rescale(np.squeeze(image_normalized[:, :, zz]),
                                       scale_vector,
                                       order=1,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            label2d_rescaled = rescale(np.squeeze(label[:, :, zz]),
                                       scale_vector,
                                       order=0,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            # ============
            # rotate to align with other datasets
            # ============
            image2d_rescaled_rotated = np.rot90(image2d_rescaled, k=0)
            label2d_rescaled_rotated = np.rot90(label2d_rescaled, k=0)

            # ============
            # crop or pad to make of the same size
            # ============
            image2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(
                image2d_rescaled_rotated, size[0], size[1])
            label2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(
                label2d_rescaled_rotated, size[0], size[1])

            # ============
            # append to list
            # ============
            image_list.append(image2d_rescaled_rotated_cropped)
            label_list.append(label2d_rescaled_rotated_cropped)

            write_buffer += 1

            # Writing needs to happen inside the loop over the slices
            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer

                _write_range_to_hdf5(data, image_list, label_list,
                                     counter_from, counter_to)

                _release_tmp_memory(image_list, label_list)

                # update counters
                counter_from = counter_to
                write_buffer = 0

    logging.info('Writing remaining data')
    counter_to = counter_from + write_buffer
    _write_range_to_hdf5(data, image_list, label_list, counter_from,
                         counter_to)
    _release_tmp_memory(image_list, label_list)

    # Write the small datasets
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames',
                             data=np.asarray(pat_names_list, dtype="S10"))

    # After test train loop:
    hdf5_file.close()
Example #19
def prepare_data(input_folder,
                 output_file,
                 idx_start,
                 idx_end,
                 protocol,
                 size,
                 depth,
                 target_resolution,
                 preprocessing_folder):

    # ========================    
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # =======================
    # create a new hdf5 file
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_slices = count_slices(filenames,
                              idx_start,
                              idx_end,
                              protocol,
                              preprocessing_folder,
                              depth)
    
    # ===============================
    # the sizes of the image and label arrays are set as: [(number of coronal slices per subject*number of subjects), size of coronal slices]
    # ===============================
    data['images'] = hdf5_file.create_dataset("images", [num_slices] + list(size), dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels", [num_slices] + list(size), dtype=np.uint8)
    
    # ===============================
    # initialize lists
    # ===============================        
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []
    
    # ===============================      
    # initialize counters
    # ===============================        
    write_buffer = 0
    counter_from = 0
    
    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):
        
        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(filenames[idx],
                                                                         protocol,
                                                                         preprocessing_folder)
        
        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)
        image = np.swapaxes(image, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        
        # ==================
        # read the label file
        # ==================        
        label, _, _ = utils.load_nii(label_path)        
        label = np.swapaxes(label, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        label = utils.group_segmentation_classes(label) # group the segmentation classes as required
                
        # ==================
        # crop volume along z axis (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
        label = utils.crop_or_pad_volume_to_size_along_z(label, depth)     

        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(float(image_hdr.get_zooms()[2])) # since axes 1 and 2 have been swapped
        pz_list.append(float(image_hdr.get_zooms()[1]))
        nx_list.append(image.shape[0]) 
        ny_list.append(image.shape[1]) # since axes 1 and 2 have been swapped
        nz_list.append(image.shape[2])
        pat_names_list.append(patient_name)
        
        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')
                        
        # ======================================================  
        ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
        # ======================================================
        scale_vector = [image_hdr.get_zooms()[0] / target_resolution[0],
                        image_hdr.get_zooms()[2] / target_resolution[1]] # since axes 1 and 2 have been swapped

        for zz in range(image.shape[2]):

            # ============
            # rescale the images and labels so that their orientation matches that of the nci dataset
            # ============            
            image2d_rescaled = rescale(np.squeeze(image_normalized[:, :, zz]),
                                                  scale_vector,
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode = 'constant')
 
            label2d_rescaled = rescale(np.squeeze(label[:, :, zz]),
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode='constant')
            
            # ============            
            # crop or pad so that all slices have the same size
            # ============            
            image2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(image2d_rescaled, size[0], size[1])
            label2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(label2d_rescaled, size[0], size[1])

            # ============   
            # append to list
            # ============   
            image_list.append(image2d_rescaled_rotated_cropped)
            label_list.append(label2d_rescaled_rotated_cropped)

            # ============   
            # increment counter
            # ============   
            write_buffer += 1

            # ============   
            # Writing needs to happen inside the loop over the slices
            # ============   
            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer

                _write_range_to_hdf5(data,
                                     image_list,
                                     label_list,
                                     counter_from,
                                     counter_to)
                
                _release_tmp_memory(image_list,
                                    label_list)

                # ============   
                # update counters 
                # ============   
                counter_from = counter_to
                write_buffer = 0
        
    # ============   
    # write leftover data
    # ============   
    logging.info('Writing remaining data')
    counter_to = counter_from + write_buffer
    _write_range_to_hdf5(data,
                         image_list,
                         label_list,
                         counter_from,
                         counter_to)
    _release_tmp_memory(image_list,
                        label_list)

    # ============   
    # Write the small datasets - image resolutions, sizes, patient ids
    # ============   
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames', data=np.asarray(pat_names_list, dtype="S10"))
    
    # ============   
    # close the hdf5 file
    # ============   
    hdf5_file.close()
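
# ===================================================================
# A minimal sketch (not the repository's actual implementation) of the buffered-write
# helpers called above. Their real definitions are not part of this excerpt, so the
# parameter names and the 'images' / 'labels' dataset keys are assumptions inferred
# from the call sites: slices are accumulated in Python lists and flushed to the
# pre-allocated hdf5 datasets in chunks of MAX_WRITE_BUFFER, keeping peak memory bounded.
# ===================================================================
import gc

import numpy as np


def _write_range_to_hdf5_sketch(hdf5_data, image_list, label_list, counter_from, counter_to):
    # stack the buffered 2d slices and write them into the
    # [counter_from, counter_to) range of the hdf5 datasets
    image_arr = np.asarray(image_list, dtype=np.float32)
    label_arr = np.asarray(label_list, dtype=np.uint8)
    hdf5_data['images'][counter_from:counter_to, ...] = image_arr
    hdf5_data['labels'][counter_from:counter_to, ...] = label_arr


def _release_tmp_memory_sketch(image_list, label_list):
    # drop the buffered slices in place so the caller's lists start empty again
    image_list.clear()
    label_list.clear()
    gc.collect()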
Пример #20
0
    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=0,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        # dirpath is the current directory, dirnames are its sub-directories, filenames are the files it contains
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            mask_path = ""
            # filter keeps only the files whose names contain '.gz'
            # sorted orders them by filename
            for filename in sorted(filter(lambda f: ".gz" in f, filenames)):  
                if "seg" in filename:
                    mask_path = os.path.join(dirpath,filename)
                    mask_slices.append(load_nii(mask_path))
                else:
                    filepath = os.path.join(dirpath, filename) 
                    image_slices.append(load_nii(filepath))

            embed()
            # keep only the slices that contain tumour
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]

                volumes[patient_id] = np.array(image_slices).transpose(1,2,3,0)
                masks[patient_id] = np.array(mask_slices).transpose(1,2,3,0)

            embed()

        # volumes is a dict mapping patient_id to its image data (without the mask); keep the sorted patient ids
        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            # split off the validation set
            validation_patients = random.sample(self.patients, k=validation_cases)  # note: k could exceed the number of available patients
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]
        embed()

        # probabilities for sampling slices based on masks
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        self.slice_weights = [(s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights]
        
        print("one hotting {} masks...".format(subset))
        self.volumes = [(v, make_one_hot(m)) for v,  m in self.volumes]
        embed()

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]
        embed()

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v,  m in self.volumes]
        embed()

        print("one hotting {} masks...".format(subset))
        self.volumes = [(v, convert_mask_to_one(m)) for v,  m in self.volumes]
        embed()

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling
        self.transform = transform
        embed()
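
    # ===================================================================
    # A hypothetical __len__ / __getitem__ pair illustrating how the slice_weights
    # computed above could be used. None of this is taken from the original class
    # (whose remaining methods are not part of this excerpt): with random_sampling
    # enabled, a patient is drawn uniformly and a slice index is then drawn with
    # probability proportional to its smoothed mask volume, so slices containing
    # the segmented structure are sampled more often. numpy is assumed to be
    # imported as np, as it already is for the code above.
    # ===================================================================
    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient_idx, slice_idx = self.patient_slice_index[idx]

        if self.random_sampling:
            # uniform over patients, weighted over slices within the patient
            patient_idx = np.random.randint(len(self.volumes))
            num_slices = self.volumes[patient_idx][0].shape[0]
            slice_idx = np.random.choice(num_slices,
                                         p=self.slice_weights[patient_idx])

        volume, mask = self.volumes[patient_idx]
        image_slice = volume[slice_idx]
        mask_slice = mask[slice_idx]

        if self.transform is not None:
            image_slice, mask_slice = self.transform((image_slice, mask_slice))

        return image_slice, mask_slice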
def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               use_iter=None):

    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')

    mask_pl, softmax_pl = model.predict(images_pl, exp_config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path,
                                           'model.ckpt-%d' % use_iter)

        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)

            if os.path.isdir(folder_path):

                if evaluate_test_set or evaluate_all:
                    train_test = 'test'  # always test
                else:
                    train_test = 'test' if (int(folder[-3:]) %
                                            5 == 0) else 'train'

                if train_test == 'test':

                    infos = {}
                    for line in open(os.path.join(folder_path, 'Info.cfg')):
                        label, value = line.split(':')
                        infos[label] = value.rstrip('\n').lstrip(' ')

                    patient_id = folder.lstrip('patient')
                    ED_frame = int(infos['ED'])
                    ES_frame = int(infos['ES'])

                    for file in glob.glob(
                            os.path.join(folder_path,
                                         'patient???_frame??.nii.gz')):

                        logging.info(
                            ' ----- Doing image: -------------------------')
                        logging.info('Doing: %s' % file)
                        logging.info(
                            ' --------------------------------------------')

                        file_base = file.split('.nii.gz')[0]

                        frame = int(file_base.split('frame')[-1])
                        img_dat = utils.load_nii(file)
                        img = img_dat[0].copy()
                        img = image_utils.normalise_image(img)

                        if gt_exists:
                            file_mask = file_base + '_gt.nii.gz'
                            mask_dat = utils.load_nii(file_mask)
                            mask = mask_dat[0]

                        start_time = time.time()

                        if exp_config.data_mode == '2D':

                            pixel_size = (img_dat[2].structarr['pixdim'][1],
                                          img_dat[2].structarr['pixdim'][2])
                            scale_vector = (pixel_size[0] /
                                            exp_config.target_resolution[0],
                                            pixel_size[1] /
                                            exp_config.target_resolution[1])

                            predictions = []

                            for zz in range(img.shape[2]):

                                slice_img = np.squeeze(img[:, :, zz])
                                slice_rescaled = transform.rescale(
                                    slice_img,
                                    scale_vector,
                                    order=1,
                                    preserve_range=True,
                                    multichannel=False,
                                    mode='constant')

                                x, y = slice_rescaled.shape

                                x_s = (x - nx) // 2
                                y_s = (y - ny) // 2
                                x_c = (nx - x) // 2
                                y_c = (ny - y) // 2

                                # Crop section of image for prediction
                                if x > nx and y > ny:
                                    slice_cropped = slice_rescaled[x_s:x_s +
                                                                   nx,
                                                                   y_s:y_s +
                                                                   ny]
                                else:
                                    slice_cropped = np.zeros((nx, ny))
                                    if x <= nx and y > ny:
                                        slice_cropped[
                                            x_c:x_c +
                                            x, :] = slice_rescaled[:, y_s:y_s +
                                                                   ny]
                                    elif x > nx and y <= ny:
                                        slice_cropped[:, y_c:y_c +
                                                      y] = slice_rescaled[
                                                          x_s:x_s + nx, :]
                                    else:
                                        slice_cropped[x_c:x_c + x, y_c:y_c +
                                                      y] = slice_rescaled[:, :]

                                # GET PREDICTION
                                network_input = np.float32(
                                    np.tile(
                                        np.reshape(slice_cropped, (nx, ny, 1)),
                                        (batch_size, 1, 1, 1)))
                                mask_out, logits_out = sess.run(
                                    [mask_pl, softmax_pl],
                                    feed_dict={images_pl: network_input})
                                prediction_cropped = np.squeeze(
                                    logits_out[0, ...])

                                # ASSEMBLE BACK THE SLICES
                                slice_predictions = np.zeros(
                                    (x, y, num_channels))
                                # insert cropped region into original image again
                                if x > nx and y > ny:
                                    slice_predictions[
                                        x_s:x_s + nx,
                                        y_s:y_s + ny, :] = prediction_cropped
                                else:
                                    if x <= nx and y > ny:
                                        slice_predictions[:, y_s:y_s +
                                                          ny, :] = prediction_cropped[
                                                              x_c:x_c +
                                                              x, :, :]
                                    elif x > nx and y <= ny:
                                        slice_predictions[
                                            x_s:x_s +
                                            nx, :, :] = prediction_cropped[:,
                                                                           y_c:
                                                                           y_c +
                                                                           y, :]
                                    else:
                                        slice_predictions[:, :, :] = prediction_cropped[
                                            x_c:x_c + x, y_c:y_c + y, :]

                                # RESCALING ON THE LOGITS
                                if gt_exists:
                                    prediction = transform.resize(
                                        slice_predictions,
                                        (mask.shape[0], mask.shape[1],
                                         num_channels),
                                        order=1,
                                        preserve_range=True,
                                        mode='constant')
                                else:  # This can occasionally lead to wrong volume size, therefore if gt_exists
                                    # we use the gt mask size for resizing.
                                    prediction = transform.rescale(
                                        slice_predictions,
                                        (1.0 / scale_vector[0],
                                         1.0 / scale_vector[1], 1),
                                        order=1,
                                        preserve_range=True,
                                        multichannel=False,
                                        mode='constant')

                                # prediction = transform.resize(slice_predictions,
                                #                               (mask.shape[0], mask.shape[1], num_channels),
                                #                               order=1,
                                #                               preserve_range=True,
                                #                               mode='constant')

                                prediction = np.uint8(
                                    np.argmax(prediction, axis=-1))
                                predictions.append(prediction)

                            prediction_arr = np.transpose(
                                np.asarray(predictions, dtype=np.uint8),
                                (1, 2, 0))

                        elif exp_config.data_mode == '3D':

                            pixel_size = (img_dat[2].structarr['pixdim'][1],
                                          img_dat[2].structarr['pixdim'][2],
                                          img_dat[2].structarr['pixdim'][3])

                            scale_vector = (pixel_size[0] /
                                            exp_config.target_resolution[0],
                                            pixel_size[1] /
                                            exp_config.target_resolution[1],
                                            pixel_size[2] /
                                            exp_config.target_resolution[2])

                            vol_scaled = transform.rescale(img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                            nz_max = exp_config.image_size[2]
                            slice_vol = np.zeros((nx, ny, nz_max),
                                                 dtype=np.float32)

                            nz_curr = vol_scaled.shape[2]
                            stack_from = (nz_max - nz_curr) // 2
                            stack_counter = stack_from

                            x, y, z = vol_scaled.shape

                            x_s = (x - nx) // 2
                            y_s = (y - ny) // 2
                            x_c = (nx - x) // 2
                            y_c = (ny - y) // 2

                            for zz in range(nz_curr):

                                slice_rescaled = vol_scaled[:, :, zz]

                                if x > nx and y > ny:
                                    slice_cropped = slice_rescaled[x_s:x_s +
                                                                   nx,
                                                                   y_s:y_s +
                                                                   ny]
                                else:
                                    slice_cropped = np.zeros((nx, ny))
                                    if x <= nx and y > ny:
                                        slice_cropped[
                                            x_c:x_c +
                                            x, :] = slice_rescaled[:, y_s:y_s +
                                                                   ny]
                                    elif x > nx and y <= ny:
                                        slice_cropped[:, y_c:y_c +
                                                      y] = slice_rescaled[
                                                          x_s:x_s + nx, :]

                                    else:
                                        slice_cropped[x_c:x_c + x, y_c:y_c +
                                                      y] = slice_rescaled[:, :]

                                slice_vol[:, :, stack_counter] = slice_cropped
                                stack_counter += 1

                            stack_to = stack_counter

                            network_input = np.float32(
                                np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))

                            start_time = time.time()
                            mask_out, logits_out = sess.run(
                                [mask_pl, softmax_pl],
                                feed_dict={images_pl: network_input})

                            logging.info('Classified 3D: %f secs' %
                                         (time.time() - start_time))

                            prediction_nzs = logits_out[0, :, :,
                                                        stack_from:stack_to,
                                                        ...]  # non-zero-slices

                            if not prediction_nzs.shape[2] == nz_curr:
                                raise ValueError('sizes mismatch')

                            # ASSEMBLE BACK THE SLICES
                            prediction_scaled = np.zeros(
                                list(vol_scaled.shape) +
                                [num_channels
                                 ])  # last dim is for logits classes

                            # insert cropped region into original image again
                            if x > nx and y > ny:
                                prediction_scaled[x_s:x_s + nx,
                                                  y_s:y_s + ny, :,
                                                  ...] = prediction_nzs
                            else:
                                if x <= nx and y > ny:
                                    prediction_scaled[:, y_s:y_s + ny, :,
                                                      ...] = prediction_nzs[
                                                          x_c:x_c + x, :, :,
                                                          ...]
                                elif x > nx and y <= ny:
                                    prediction_scaled[x_s:x_s + nx, :, :, ...] = \
                                        prediction_nzs[:, y_c:y_c + y, :, ...]
                                else:
                                    prediction_scaled[:, :, :, ...] = \
                                        prediction_nzs[x_c:x_c + x, y_c:y_c + y, :, ...]

                            logging.info('Prediction_scaled mean %f' %
                                         (np.mean(prediction_scaled)))

                            prediction = transform.resize(
                                prediction_scaled,
                                (mask.shape[0], mask.shape[1], mask.shape[2],
                                 num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                            prediction = np.argmax(prediction, axis=-1)
                            prediction_arr = np.asarray(prediction,
                                                        dtype=np.uint8)

                        # This is the same for 2D and 3D again
                        if do_postprocessing:
                            prediction_arr = image_utils.keep_largest_connected_components(
                                prediction_arr)

                        elapsed_time = time.time() - start_time
                        total_time += elapsed_time
                        total_volumes += 1

                        logging.info('Evaluation of volume took %f secs.' %
                                     elapsed_time)

                        if frame == ED_frame:
                            frame_suffix = '_ED'
                        elif frame == ES_frame:
                            frame_suffix = '_ES'
                        else:
                            raise ValueError(
                                'Frame does not correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                                % (frame, ED_frame, ES_frame))

                        # Save predicted mask
                        out_file_name = os.path.join(
                            output_folder, 'prediction',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        if gt_exists:
                            out_affine = mask_dat[1]
                            out_header = mask_dat[2]
                        else:
                            out_affine = img_dat[1]
                            out_header = img_dat[2]

                        logging.info('saving to: %s' % out_file_name)
                        utils.save_nii(out_file_name, prediction_arr,
                                       out_affine, out_header)

                        # Save image data to the same folder for convenience
                        image_file_name = os.path.join(
                            output_folder, 'image',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % image_file_name)
                        utils.save_nii(image_file_name, img_dat[0], out_affine,
                                       out_header)

                        if gt_exists:

                            # Save GT image
                            gt_file_name = os.path.join(
                                output_folder, 'ground_truth', 'patient' +
                                patient_id + frame_suffix + '.nii.gz')
                            logging.info('saving to: %s' % gt_file_name)
                            utils.save_nii(gt_file_name, mask, out_affine,
                                           out_header)

                            # Save difference mask between predictions and ground truth
                            difference_mask = np.where(
                                np.abs(prediction_arr - mask) > 0, [1], [0])
                            difference_mask = np.asarray(difference_mask,
                                                         dtype=np.uint8)
                            diff_file_name = os.path.join(
                                output_folder, 'difference', 'patient' +
                                patient_id + frame_suffix + '.nii.gz')
                            logging.info('saving to: %s' % diff_file_name)
                            utils.save_nii(diff_file_name, difference_mask,
                                           out_affine, out_header)

        logging.info('Average time per volume: %f' %
                     (total_time / total_volumes))

    return init_iteration
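
# ===================================================================
# The x_s / y_s / x_c / y_c arithmetic used repeatedly in score_data implements a
# centre crop-or-pad of a 2D slice to (nx, ny). The sketch below isolates that idea
# as a standalone helper; it is a hypothetical re-implementation (presumably close
# in spirit to utils.crop_or_pad_slice_to_size used in the other examples, whose
# actual code is not shown here).
# ===================================================================
import numpy as np


def crop_or_pad_slice_to_size_sketch(slice_2d, nx, ny):
    x, y = slice_2d.shape

    x_s = (x - nx) // 2  # crop offsets (used when the slice is larger)
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2  # pad offsets (used when the slice is smaller)
    y_c = (ny - y) // 2

    if x > nx and y > ny:
        # larger in both dimensions: centre crop
        return slice_2d[x_s:x_s + nx, y_s:y_s + ny]

    slice_cropped = np.zeros((nx, ny), dtype=slice_2d.dtype)
    if x <= nx and y > ny:
        # pad along x, crop along y
        slice_cropped[x_c:x_c + x, :] = slice_2d[:, y_s:y_s + ny]
    elif x > nx and y <= ny:
        # crop along x, pad along y
        slice_cropped[:, y_c:y_c + y] = slice_2d[x_s:x_s + nx, :]
    else:
        # smaller in both dimensions: centre pad
        slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_2d
    return slice_cropped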
def prepare_data(input_folder,
                 output_file,
                 size,
                 target_resolution,
                 cv_fold_num):

    # =======================
    # =======================
    image_folder = os.path.join(input_folder, 'Prostate-3T')
    mask_folder = os.path.join(input_folder, 'NCI_ISBI_Challenge-Prostate3T_Training_Segmentations')

    # =======================
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # =======================
    # =======================
    logging.info('Counting files and parsing meta data...')
    folder_list = get_patient_folders(image_folder,
                                      folder_base='Prostate3T-01',
                                      cv_fold_number = cv_fold_num)
    
    num_slices = count_slices(image_folder,
                              folder_base='Prostate3T-01',
                              cv_fold_number = cv_fold_num)
    
    nx, ny = size
    n_test = num_slices['test']
    n_train = num_slices['train']
    n_val = num_slices['validation']

    # =======================
    # =======================
    print('Debug: Check if sets add up to correct value:')
    print(n_train, n_val, n_test, n_train + n_val + n_test)

    # =======================
    # Create datasets for images and masks
    # =======================
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'], [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset("images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['masks_%s' % tt] = hdf5_file.create_dataset("masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    mask_list = {'test': [], 'train': [], 'validation': []}
    img_list = {'test': [], 'train': [], 'validation': []}
    nx_list = {'test': [], 'train': [], 'validation': []}
    ny_list = {'test': [], 'train': [], 'validation': []}
    nz_list = {'test': [], 'train': [], 'validation': []}
    px_list = {'test': [], 'train': [], 'validation': []}
    py_list = {'test': [], 'train': [], 'validation': []}
    pz_list = {'test': [], 'train': [], 'validation': []}
    pat_names_list = {'test': [], 'train': [], 'validation': []}

    # =======================
    # =======================
    logging.info('Parsing image files')
    for train_test in ['test', 'train', 'validation']:

        write_buffer = 0
        counter_from = 0

        patient_counter = 0

        for folder in folder_list[train_test]:

            patient_counter += 1

            logging.info('================================')
            logging.info('Doing: %s' % folder)
            pat_names_list[train_test].append(str(folder.split('-')[-1]))

            lstFilesDCM = []  # create an empty list
            
            for dirName, subdirList, fileList in os.walk(folder):
            
                # fileList.sort()
                for filename in fileList:
                
                    if ".dcm" in filename.lower():  # check whether the file's DICOM
                        lstFilesDCM.append(os.path.join(dirName, filename))

            # Get ref file
            RefDs = dicom.read_file(lstFilesDCM[0])

            # Load dimensions based on the number of rows, columns, and slices (along the Z axis)
            ConstPixelDims = (int(RefDs.Rows), int(RefDs.Columns), len(lstFilesDCM))

            # Load spacing values (in mm)
            pixel_size = (float(RefDs.PixelSpacing[0]), float(RefDs.PixelSpacing[1]), float(RefDs.SliceThickness))
            px_list[train_test].append(float(RefDs.PixelSpacing[0]))
            py_list[train_test].append(float(RefDs.PixelSpacing[1]))
            pz_list[train_test].append(float(RefDs.SliceThickness))

            print('PixelDims')
            print(ConstPixelDims)
            print('PixelSpacing')
            print(pixel_size)

            # The array is sized based on 'ConstPixelDims'
            img = np.zeros(ConstPixelDims, dtype=RefDs.pixel_array.dtype)

            # loop through all the DICOM files
            for filenameDCM in lstFilesDCM:

                # read the file
                ds = dicom.read_file(filenameDCM)

                # ======
                # store the raw image data
                # img[:, :, lstFilesDCM.index(filenameDCM)] = ds.pixel_array
                # the position in lstFilesDCM is not a reliable slice order;
                # the InstanceNumber field gives the correct slice position.
                # ======
                img[:, :, ds.InstanceNumber - 1] = ds.pixel_array
                
            # ================================
            # save as nifti, this sets the affine transformation as an identity matrix
            # ================================    
            nifti_img_path = lstFilesDCM[0][:lstFilesDCM[0].rfind('/')+1]
            utils.save_nii(img_path = nifti_img_path + 'img.nii.gz', data = img, affine = np.eye(4))
    
            # ================================
            # do bias field correction
            # ================================
            input_img = nifti_img_path + 'img.nii.gz'
            output_img = nifti_img_path + 'img_n4.nii.gz'
            subprocess.call(["/usr/bmicnas01/data-biwi-01/bmicdatasets/Sharing/N4_th", input_img, output_img])
    
            # ================================    
            # read bias corrected image
            # ================================    
            img = utils.load_nii(img_path = nifti_img_path + 'img_n4.nii.gz')[0]

            # ================================    
            # normalize the image
            # ================================    
            img = utils.normalise_image(img, norm_type='div_by_max')

            # ================================    
            # read the labels
            # ================================    
            mask_path = os.path.join(mask_folder, folder.split('/')[-1] + '.nrrd')
            mask, options = nrrd.read(mask_path)

            # fix swap axis
            mask = np.swapaxes(mask, 0, 1)
            
            # ================================
            # save as nifti, this sets the affine transformation as an identity matrix
            # ================================    
            utils.save_nii(img_path = nifti_img_path + 'lbl.nii.gz', data = mask, affine = np.eye(4))
            
            nx_list[train_test].append(mask.shape[0])
            ny_list[train_test].append(mask.shape[1])
            nz_list[train_test].append(mask.shape[2])

            print('mask.shape')
            print(mask.shape)
            print('img.shape')
            print(img.shape)

            ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
            scale_vector = [pixel_size[0] / target_resolution[0],
                            pixel_size[1] / target_resolution[1]]

            for zz in range(img.shape[2]):

                slice_img = np.squeeze(img[:, :, zz])
                slice_rescaled = transform.rescale(slice_img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode = 'constant')

                slice_mask = np.squeeze(mask[:, :, zz])
                mask_rescaled = transform.rescale(slice_mask,
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode='constant')

                slice_cropped = utils.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                mask_cropped = utils.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)

                img_list[train_test].append(slice_cropped)
                mask_list[train_test].append(mask_cropped)

                write_buffer += 1

                # Writing needs to happen inside the loop over the slices
                if write_buffer >= MAX_WRITE_BUFFER:

                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, img_list, mask_list, counter_from, counter_to)
                    _release_tmp_memory(img_list, mask_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0


        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, img_list, mask_list, counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, train_test)

    # Write the small datasets
    for tt in ['test', 'train', 'validation']:
        hdf5_file.create_dataset('nx_%s' % tt, data=np.asarray(nx_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('ny_%s' % tt, data=np.asarray(ny_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('nz_%s' % tt, data=np.asarray(nz_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('px_%s' % tt, data=np.asarray(px_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('py_%s' % tt, data=np.asarray(py_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('pz_%s' % tt, data=np.asarray(pz_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('patnames_%s' % tt, data=np.asarray(pat_names_list[tt], dtype="S10"))
    
    # After test train loop:
    hdf5_file.close()
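
# ===================================================================
# A standalone sketch of the InstanceNumber-based volume assembly used in the DICOM
# loop above: the file order returned by os.walk is not a reliable slice order, so
# each slice is placed at index InstanceNumber - 1 instead. This is an illustration
# only; it assumes pydicom is available (the code above imports it as `dicom`) and
# uses the current pydicom.dcmread API rather than the older read_file alias.
# ===================================================================
import os

import numpy as np
import pydicom


def read_dicom_volume_sketch(series_folder):
    dcm_paths = [os.path.join(dirname, fname)
                 for dirname, _, fnames in os.walk(series_folder)
                 for fname in fnames if fname.lower().endswith('.dcm')]

    ref = pydicom.dcmread(dcm_paths[0])
    volume = np.zeros((int(ref.Rows), int(ref.Columns), len(dcm_paths)),
                      dtype=ref.pixel_array.dtype)

    for path in dcm_paths:
        ds = pydicom.dcmread(path)
        # InstanceNumber is 1-based and encodes the slice position within the series
        volume[:, :, int(ds.InstanceNumber) - 1] = ds.pixel_array

    pixel_size = (float(ref.PixelSpacing[0]),
                  float(ref.PixelSpacing[1]),
                  float(ref.SliceThickness))
    return volume, pixel_size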
Пример #23
0
def prepare_data(
        input_folder,
        output_file,
        mode,
        size,  # for 3d: (nz, nx, ny), for 2d: (nx, ny)
        target_resolution,  # for 3d: (px, py, pz), for 2d: (px, py)
        cv_fold_num):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    assert (mode in ['2D', '3D']), 'Unknown mode: %s' % mode
    if mode == '2D' and not len(size) == 2:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '3D' and not len(size) == 3:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '2D' and not len(target_resolution) == 2:
        raise AssertionError(
            'Inadequate number of target resolution parameters')
    if mode == '3D' and not len(target_resolution) == 3:
        raise AssertionError(
            'Inadequate number of target resolution parameters')

    # ============
    # create an empty hdf5 file
    # ============
    hdf5_file = h5py.File(output_file, "w")

    # ============
    # create empty lists for filling header info
    # ============
    diag_list = {'test': [], 'train': [], 'validation': []}
    height_list = {'test': [], 'train': [], 'validation': []}
    weight_list = {'test': [], 'train': [], 'validation': []}
    patient_id_list = {'test': [], 'train': [], 'validation': []}
    cardiac_phase_list = {'test': [], 'train': [], 'validation': []}
    nx_list = {'test': [], 'train': [], 'validation': []}
    ny_list = {'test': [], 'train': [], 'validation': []}
    nz_list = {'test': [], 'train': [], 'validation': []}
    px_list = {'test': [], 'train': [], 'validation': []}
    py_list = {'test': [], 'train': [], 'validation': []}
    pz_list = {'test': [], 'train': [], 'validation': []}

    file_list = {'test': [], 'train': [], 'validation': []}
    num_slices = {'test': 0, 'train': 0, 'validation': 0}

    # ============
    # go through all images and get header info.
    # one round of parsing is done just to get all the header info. The size info is used to create empty fields for the images and labels, with the required sizes.
    # Then, another round of reading the images and labels is done, which are pre-processed and written into the hdf5 file
    # ============
    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            # ============
            # train_test_validation split
            # ============
            train_test = test_train_val_split(patient_id=int(folder[-3:]),
                                              cv_fold_number=cv_fold_num)

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            # ============
            # reading this patient's image and collecting header info
            # ============
            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??_n4.nii.gz')):

                # ============
                # list with file paths
                # ============
                file_list[train_test].append(file)

                diag_list[train_test].append(diagnosis_dict[infos['Group']])
                weight_list[train_test].append(infos['Weight'])
                height_list[train_test].append(infos['Height'])

                patient_id_list[train_test].append(patient_id)

                systole_frame = int(infos['ES'])
                diastole_frame = int(infos['ED'])

                file_base = file.split('.')[0]
                frame = int(file_base.split('frame')[-1][:-3])
                if frame == systole_frame:
                    cardiac_phase_list[train_test].append(1)  # 1 == systole
                elif frame == diastole_frame:
                    cardiac_phase_list[train_test].append(2)  # 2 == diastole
                else:
                    cardiac_phase_list[train_test].append(
                        0)  # 0 means other phase

                nifty_img = nib.load(file)
                nx_list[train_test].append(nifty_img.shape[0])
                ny_list[train_test].append(nifty_img.shape[1])
                nz_list[train_test].append(nifty_img.shape[2])
                num_slices[train_test] += nifty_img.shape[2]
                py_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][2])
                px_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][1])
                pz_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][3])

    # ============
    # writing the small datasets
    # ============
    for tt in ['test', 'train', 'validation']:
        hdf5_file.create_dataset('diagnosis_%s' % tt,
                                 data=np.asarray(diag_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('weight_%s' % tt,
                                 data=np.asarray(weight_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('height_%s' % tt,
                                 data=np.asarray(height_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('patient_id_%s' % tt,
                                 data=np.asarray(patient_id_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('cardiac_phase_%s' % tt,
                                 data=np.asarray(cardiac_phase_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('nz_%s' % tt,
                                 data=np.asarray(nz_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('ny_%s' % tt,
                                 data=np.asarray(ny_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('nx_%s' % tt,
                                 data=np.asarray(nx_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('py_%s' % tt,
                                 data=np.asarray(py_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('px_%s' % tt,
                                 data=np.asarray(px_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('pz_%s' % tt,
                                 data=np.asarray(pz_list[tt],
                                                 dtype=np.float32))

    # ============
    # setting sizes according to 2d or 3d
    # ============
    if mode == '3D':  # size [num_patients, nz, nx, ny]
        nz_max, nx, ny = size
        n_train = len(file_list['train'])  # number of patients
        n_test = len(file_list['test'])
        n_val = len(file_list['validation'])

    elif mode == '2D':  # size [num_z_slices_across_all_patients, nx, ny]
        nx, ny = size
        n_test = num_slices['test']
        n_train = num_slices['train']
        n_val = num_slices['validation']

    else:
        raise AssertionError('Wrong mode setting. This should never happen.')

    # ============
    # creating datasets for images and labels
    # ============
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'],
                              [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['labels_%s' % tt] = hdf5_file.create_dataset(
                "labels_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    image_list = {'test': [], 'train': [], 'validation': []}
    label_list = {'test': [], 'train': [], 'validation': []}

    for train_test in ['test', 'train', 'validation']:

        write_buffer = 0
        counter_from = 0
        patient_counter = 0

        for image_file in file_list[train_test]:

            patient_counter += 1

            logging.info('============================================')
            logging.info('Doing: %s' % image_file)

            # ============
            # read image
            # ============
            image_dat = utils.load_nii(image_file)
            image = image_dat[0].copy()

            # ============
            # normalize the image to be between 0 and 1
            # ============
            image = utils.normalise_image(image, norm_type='div_by_max')

            # ============
            # read label
            # ============
            file_base = image_file.split('_n4.nii.gz')[0]
            label_file = file_base + '_gt.nii.gz'
            label_dat = utils.load_nii(label_file)
            label = label_dat[0].copy()

            # ============
            # set RV label to 1 and other labels to 0, as the RVSC dataset only has labels for the RV
            # original labels: 0 background, 1 right ventricle, 2 myocardium, 3 left ventricle
            # ============
            # label[label!=1] = 0

            # ============
            # original pixel size (px, py, pz)
            # ============
            pixel_size = (image_dat[2].structarr['pixdim'][1],
                          image_dat[2].structarr['pixdim'][2],
                          image_dat[2].structarr['pixdim'][3])

            # ========================================================================
            # PROCESSING LOOP FOR 3D DATA
            # ========================================================================
            if mode == '3D':

                # rescaling ratio
                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1],
                    pixel_size[2] / target_resolution[2]
                ]

                # ==============================
                # rescale image and label
                # ==============================
                image_scaled = transform.rescale(image,
                                                 scale_vector,
                                                 order=1,
                                                 preserve_range=True,
                                                 multichannel=False,
                                                 mode='constant')
                label_scaled = transform.rescale(label,
                                                 scale_vector,
                                                 order=0,
                                                 preserve_range=True,
                                                 multichannel=False,
                                                 mode='constant')

                # ==============================
                # ==============================
                image_scaled = utils.crop_or_pad_volume_to_size_along_z(
                    image_scaled, nz_max)
                label_scaled = utils.crop_or_pad_volume_to_size_along_z(
                    label_scaled, nz_max)

                # ==============================
                # nz_max is the z-dimension provided in the 'size' parameter
                # ==============================
                image_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)
                label_vol = np.zeros((nx, ny, nz_max), dtype=np.uint8)

                # ===============================
                # going through each z slice
                # ===============================
                for zz in range(nz_max):

                    image_slice = image_scaled[:, :, zz]
                    label_slice = label_scaled[:, :, zz]

                    # cropping / padding with zeros the x-y slice at this z location
                    image_slice_cropped = utils.crop_or_pad_slice_to_size(
                        image_slice, nx, ny)
                    label_slice_cropped = utils.crop_or_pad_slice_to_size(
                        label_slice, nx, ny)

                    image_vol[:, :, zz] = image_slice_cropped
                    label_vol[:, :, zz] = label_slice_cropped

                # ===============================
                # swap axes to maintain consistent orientation as compared to 2d pre-processing
                # ===============================
                image_vol = image_vol.swapaxes(0, 2).swapaxes(1, 2)
                label_vol = label_vol.swapaxes(0, 2).swapaxes(1, 2)

                # ===============================
                # append to list that will be written to the hdf5 file
                # ===============================
                image_list[train_test].append(image_vol)
                label_list[train_test].append(label_vol)

                write_buffer += 1

                # ===============================
                # writing the images and labels pre-processed so far to the hdf5 file
                # ===============================
                if write_buffer >= MAX_WRITE_BUFFER:

                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, image_list,
                                         label_list, counter_from, counter_to)
                    _release_tmp_memory(image_list, label_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0

            # ========================================================================
            # PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA
            # ========================================================================
            elif mode == '2D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1]
                ]

                # ===============================
                # go through each z slice, rescale and crop and append.
                # in this process, the z axis will become the zeroth axis
                # ===============================
                for zz in range(image.shape[2]):

                    image_slice = np.squeeze(image[:, :, zz])
                    label_slice = np.squeeze(label[:, :, zz])

                    image_slice_rescaled = transform.rescale(
                        image_slice,
                        scale_vector,
                        order=1,
                        preserve_range=True,
                        multichannel=False,
                        mode='constant')
                    label_slice_rescaled = transform.rescale(
                        label_slice,
                        scale_vector,
                        order=0,
                        preserve_range=True,
                        multichannel=False,
                        mode='constant')

                    image_slice_cropped = utils.crop_or_pad_slice_to_size(
                        image_slice_rescaled, nx, ny)
                    label_slice_cropped = utils.crop_or_pad_slice_to_size(
                        label_slice_rescaled, nx, ny)

                    image_list[train_test].append(image_slice_cropped)
                    label_list[train_test].append(label_slice_cropped)

                    write_buffer += 1

                    # Writing needs to happen inside the loop over the slices
                    if write_buffer >= MAX_WRITE_BUFFER:

                        counter_to = counter_from + write_buffer
                        _write_range_to_hdf5(data, train_test, image_list,
                                             label_list, counter_from,
                                             counter_to)
                        _release_tmp_memory(image_list, label_list, train_test)

                        # reset stuff for next iteration
                        counter_from = counter_to
                        write_buffer = 0

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, image_list, label_list,
                             counter_from, counter_to)
        _release_tmp_memory(image_list, label_list, train_test)

    # After test train loop:
    hdf5_file.close()
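
# ===================================================================
# test_train_val_split is called in the function above but is not part of this
# excerpt. The sketch below is one plausible, purely hypothetical implementation,
# loosely consistent with the "patient id % 5 == 0 -> test" rule seen in score_data
# earlier: one residue class (shifted by the cross-validation fold number) becomes
# the test set, the next one the validation set, and the remaining patients are
# used for training. The real repository function may differ.
# ===================================================================
def test_train_val_split_sketch(patient_id, cv_fold_number):
    test_residue = cv_fold_number % 5
    val_residue = (cv_fold_number + 1) % 5

    if patient_id % 5 == test_residue:
        return 'test'
    elif patient_id % 5 == val_residue:
        return 'validation'
    else:
        return 'train'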
def prepare_data(input_folder, output_file, idx_start, idx_end, protocol, size,
                 target_resolution, preprocessing_folder):
    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # =======================
    # create a hdf5 file
    # =======================
    # hdf5_file = h5py.File(output_file, "w")
    #
    # # ===============================
    # # Create datasets for images and labels
    # # ===============================
    # data = {}
    # num_subjects = idx_end - idx_start
    #
    # data['images'] = hdf5_file.create_dataset("images", [num_subjects] + list(size), dtype=np.float32)
    # data['labels'] = hdf5_file.create_dataset("labels", [num_subjects] + list(size), dtype=np.uint8)
    #
    # ===============================
    # initialize lists
    # ===============================
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []

    # ===============================
    # initiate counter
    # ===============================
    patient_counter = 0

    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):
        logging.info('Volume {} of {}...'.format(idx, idx_end))

        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(
            filenames[idx], protocol, preprocessing_folder)

        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)

        # ==================
        # read the label file
        # ==================
        label, _, _ = utils.load_nii(label_path)
        label = utils.group_segmentation_classes(
            label)  # group the segmentation classes as required

        # # ==================
        # # collect some header info.
        # # ==================
        # px_list.append(float(image_hdr.get_zooms()[0]))
        # py_list.append(float(image_hdr.get_zooms()[1]))
        # pz_list.append(float(image_hdr.get_zooms()[2]))
        # nx_list.append(image.shape[0])
        # ny_list.append(image.shape[1])
        # nz_list.append(image.shape[2])
        # pat_names_list.append(patient_name)

        # ==================
        # crop volume along all axes from the ends (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_x(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_x(label, 256)
        image = utils.crop_or_pad_volume_to_size_along_y(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_y(label, 256)
        image = utils.crop_or_pad_volume_to_size_along_z(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_z(label, 256)

        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')

        # ======================================================
        # rescale, crop / pad to make all images of the required size and resolution
        # ======================================================
        scale_vector = [
            image_hdr.get_zooms()[0] / target_resolution[0],
            image_hdr.get_zooms()[1] / target_resolution[1],
            image_hdr.get_zooms()[2] / target_resolution[2]
        ]

        image_rescaled = transform.rescale(image_normalized,
                                           scale_vector,
                                           order=1,
                                           preserve_range=True,
                                           multichannel=False,
                                           mode='constant')

        # label_onehot = utils.make_onehot(label, nlabels=15)
        #
        # label_onehot_rescaled = transform.rescale(label_onehot,
        #                                           scale_vector,
        #                                           order=1,
        #                                           preserve_range=True,
        #                                           multichannel=True,
        #                                           mode='constant')
        #
        # label_rescaled = np.argmax(label_onehot_rescaled, axis=-1)
        #
        # # ============
        # # the images and labels have been rescaled to the desired resolution.
        # # write them to the hdf5 file now.
        # # ============
        # image_list.append(image_rescaled)
        # label_list.append(label_rescaled)

        # ============
        # write to file
        # ============
        volume_dir = os.path.join(preprocessing_folder,
                                  'volume_{:06d}'.format(idx))
        os.makedirs(volume_dir, exist_ok=True)
        for i in range(size[1]):
            slice_path = os.path.join(volume_dir,
                                      'slice_{:06d}.jpeg'.format(i))
            slice_arr = image_rescaled[:, i, :] * 255
            slice_img = Image.fromarray(slice_arr.astype(np.uint8))
            slice_img.save(slice_path)
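The scale_vector above converts the native voxel spacing from the NIfTI header into per-axis scale factors, native_spacing / target_resolution, before handing the volume to skimage. A minimal, self-contained sketch of that resampling step, with made-up spacing values that are not taken from any dataset here:

import numpy as np
from skimage import transform

# hypothetical native voxel spacing (mm) and desired target resolution (mm)
native_spacing = (0.7, 0.7, 0.7)
target_resolution = (1.0, 1.0, 1.0)

# scale factor per axis: native spacing / target spacing
scale_vector = [native_spacing[i] / target_resolution[i] for i in range(3)]

volume = np.random.rand(256, 256, 256).astype(np.float32)
volume_rescaled = transform.rescale(volume,
                                    scale_vector,
                                    order=1,             # linear interpolation for intensity images
                                    preserve_range=True,
                                    mode='constant')

print(volume_rescaled.shape)  # roughly (179, 179, 179) for a 0.7 mm -> 1.0 mm resampling

Labels are handled differently in the scripts here (either order=0 or a one-hot rescale followed by argmax, as in the commented-out block above) so that interpolation does not invent new class values.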
Example #25
def prepare_data(input_folder, output_file, size, input_channels,
                 target_resolution):

    hdf5_file = h5py.File(output_file, "w")

    file_list = []

    logging.info('Counting files and parsing meta data...')

    pid = 0
    for folder in os.listdir(input_folder + '/img'):
        print(get_im_id(folder))
        # train_test = test_train_val_split(pid)
        pid = pid + 1
        file_list.append(get_im_id(folder))

    n_train = len(file_list)

    print('Debug: Check if sets add up to correct value:')
    print(n_train)

    # Create datasets for images and masks
    data = {}
    for tt, num_points in zip(['train'], [n_train]):

        if num_points > 0:
            print([num_points] + list(size) + [input_channels])
            if input_channels != 1:
                data['images_%s' % tt] = hdf5_file.create_dataset(
                    "images_%s" % tt,
                    [num_points] + list(size) + [input_channels],
                    dtype=np.float32)
                data['masks_%s' % tt] = hdf5_file.create_dataset(
                    "masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)
                data['pids_%s' % tt] = hdf5_file.create_dataset(
                    "pids_%s" % tt, [num_points],
                    dtype=h5py.special_dtype(vlen=str))
            else:
                data['images_%s' % tt] = hdf5_file.create_dataset(
                    "images_%s" % tt, [num_points] + list(size),
                    dtype=np.float32)
                data['masks_%s' % tt] = hdf5_file.create_dataset(
                    "masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)
                data['pids_%s' % tt] = hdf5_file.create_dataset(
                    "pids_%s" % tt, [num_points],
                    dtype=h5py.special_dtype(vlen=str))

    mask_list = []
    img_list = []
    pids_list = []

    logging.info('Parsing image files')

    # get the maximum dimensions along each axis
    maxX = 0
    maxY = 0
    maxZ = 0
    i = 0
    for train_test in ['train']:
        for file in file_list:
            print("Doing file {}".format(i))
            i += 1

            baseFilePath = os.path.join(input_folder, 'img',
                                        'img' + file + '.nii.gz')
            img_dat, _, img_header = utils.load_nii(baseFilePath)

            maxX = max(maxX, img_dat.shape[0])
            maxY = max(maxY, img_dat.shape[1])
            maxZ = max(maxZ, img_dat.shape[2])

    print("Max x: {}, y: {}, z: {}".format(maxX, maxY, maxZ))

    for train_test in ['train']:

        write_buffer = 0
        counter_from = 0

        for file in file_list:

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % file)

            patient_id = file

            baseFilePath = os.path.join(input_folder, 'img',
                                        'img' + file + '.nii.gz')
            img, _, img_header = utils.load_nii(baseFilePath)
            mask, _, _ = utils.load_nii(
                os.path.join(input_folder, 'label',
                             'label' + file + '.nii.gz'))

            # print("mask sum ", np.sum(mask))

            img = pad_slice_to_size(img, (512, 512, 256))
            mask = pad_slice_to_size(mask, (512, 512, 256))

            print_info(img, "X")
            print_info(mask, "Y")

            # print("mask sum ", np.sum(mask))
            scale_vector = target_resolution

            if scale_vector != [1.0]:
                # print(img.shape)
                # #img = transform.resize(img, size)
                # img = transform.rescale(img, scale_vector[0], anti_aliasing=False, preserve_range=True)
                # #mask = transform.resize(mask, size)
                # mask = rescale_labels(mask, scale_vector[0])

                img = F.interpolate(
                    torch.from_numpy(img)[None, None, :, :, :].float(),
                    size=size,
                    mode='trilinear',
                    align_corners=True).numpy()[0, 0, ...]
                mask = rescale_labels(mask, scale_vector[0], size)

                np.save('checkpoints/images/img.npy', img)
                np.save('checkpoints/images/mask.npy', mask)

                print_info(img, "x")
                print_info(mask, "y", unique=True)

            # print("mask sum ", np.sum(mask))
            img = normalise_image(img)

            img_list.append(img)
            mask_list.append(mask)
            pids_list.append(patient_id)

            write_buffer += 1

            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer
                _write_range_to_hdf5(data, 'train', img_list, mask_list,
                                     pids_list, counter_from, counter_to)
                _release_tmp_memory(img_list, mask_list, pids_list, 'train')

                # reset stuff for next iteration
                counter_from = counter_to
                write_buffer = 0

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        if len(file_list) > 0:
            _write_range_to_hdf5(data, 'train', img_list, mask_list, pids_list,
                                 counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, pids_list, 'train')

    # After test train loop:
    hdf5_file.close()
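The resizing above uses torch.nn.functional.interpolate, which for trilinear mode expects a 5D tensor of shape (batch, channel, depth, height, width). A minimal sketch of that pattern with made-up sizes (rescale_labels is specific to this repository and is not reproduced here):

import numpy as np
import torch
import torch.nn.functional as F

volume = np.random.rand(512, 512, 256).astype(np.float32)
target_size = (256, 256, 128)  # hypothetical output grid

# add batch and channel dimensions -> (1, 1, D, H, W), interpolate, then drop them again
resized = F.interpolate(torch.from_numpy(volume)[None, None, :, :, :],
                        size=target_size,
                        mode='trilinear',
                        align_corners=True).numpy()[0, 0, ...]

print(resized.shape)  # (256, 256, 128)

For label volumes, mode='nearest' (which does not accept align_corners) is the usual choice, so that interpolation does not blend class indices.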
def prepare_data(input_folder, output_file, input_channels):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    hdf5_file = h5py.File(output_file, "w")

    file_list = {'test': [], 'train': [], 'validation': []}
    num_slices = {'test': 0, 'train': 0, 'validation': 0}

    logging.info('Counting files and parsing meta data...')

    pid = 0
    for folder in os.listdir(input_folder):
        print(folder)
        train_test = test_train_val_split(pid)
        pid = pid + 1
        file_list[train_test].append(folder)

    n_train = len(file_list['train'])
    n_test = len(file_list['test'])
    n_val = len(file_list['validation'])

    print('Debug: Check if sets add up to correct value:')
    print(n_train, n_val, n_test, n_train + n_val + n_test)

    # Create datasets for images and masks
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'],
                              [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt,
                [num_points] + [160, 192, 160] + [input_channels],
                dtype=np.float32)
            data['pids_%s' % tt] = hdf5_file.create_dataset(
                "pids_%s" % tt, [num_points],
                dtype=h5py.special_dtype(vlen=str))
            data['xOffsets_%s' % tt] = hdf5_file.create_dataset(
                "xOffsets_%s" % tt, [num_points], dtype=np.int32)
            data['yOffsets_%s' % tt] = hdf5_file.create_dataset(
                "yOffsets_%s" % tt, [num_points], dtype=np.int32)
            data['zOffsets_%s' % tt] = hdf5_file.create_dataset(
                "zOffsets_%s" % tt, [num_points], dtype=np.int32)

    img_list = {'test': [], 'train': [], 'validation': []}
    pids_list = {'test': [], 'train': [], 'validation': []}
    xOffsest_list = {'test': [], 'train': [], 'validation': []}
    yOffsest_list = {'test': [], 'train': [], 'validation': []}
    zOffsest_list = {'test': [], 'train': [], 'validation': []}

    logging.info('Parsing image files')

    # Uncomment for calculating the needed image dimension
    # #get max dimension in z-axis
    # maxX = 0
    # maxY = 0
    # maxZ = 0
    # maxXCropped = 0
    # maxYCropped = 0
    # maxZCropped = 0
    # i = 0
    # for train_test in ['test', 'train', 'validation']:
    #     for folder in file_list[train_test]:
    #         print("Doing file {}".format(i))
    #         i += 1
    #
    #         baseFilePath = os.path.join(input_folder, folder, folder)
    #         img_c1, _, img_header = utils.load_nii(baseFilePath + "_t1.nii.gz")
    #         img_c2, _, _ = utils.load_nii(baseFilePath + "_t1ce.nii.gz")
    #         img_c3, _, _ = utils.load_nii(baseFilePath + "_t2.nii.gz")
    #         img_c4, _, _ = utils.load_nii(baseFilePath + "_flair.nii.gz")
    #         img_dat = np.stack((img_c1, img_c2, img_c3, img_c4), 3)
    #
    #         maxX = max(maxX, img_dat.shape[0])
    #         maxY = max(maxY, img_dat.shape[1])
    #         maxZ = max(maxZ, img_dat.shape[2])
    #         img_dat_cropped = crop_volume_allDim(img_dat)
    #         maxXCropped = max(maxXCropped, img_dat_cropped.shape[0])
    #         maxYCropped = max(maxYCropped, img_dat_cropped.shape[1])
    #         maxZCropped = max(maxZCropped, img_dat_cropped.shape[2])
    # print("Max x: {}, y: {}, z: {}".format(maxX, maxY, maxZ))
    # print("Max cropped x: {}, y: {}, z: {}".format(maxXCropped, maxYCropped, maxZCropped))

    for train_test in ['train', 'test', 'validation']:

        write_buffer = 0
        counter_from = 0

        for folder in file_list[train_test]:

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % folder)

            patient_id = folder

            baseFilePath = os.path.join(input_folder, folder, folder)
            img_c1, _, img_header = utils.load_nii(baseFilePath + "_t1.nii.gz")
            img_c2, _, _ = utils.load_nii(baseFilePath + "_t1ce.nii.gz")
            img_c3, _, _ = utils.load_nii(baseFilePath + "_t2.nii.gz")
            img_c4, _, _ = utils.load_nii(baseFilePath + "_flair.nii.gz")

            img_dat = np.stack((img_c1, img_c2, img_c3, img_c4), 3)

            img, offsets = crop_volume_allDim(img_dat.copy())

            pixel_size = (img_header.structarr['pixdim'][1],
                          img_header.structarr['pixdim'][2],
                          img_header.structarr['pixdim'][3])

            logging.info('Pixel size:')
            logging.info(pixel_size)

            ### PROCESSING LOOP FOR 3D DATA ################################
            img = crop_or_pad_slice_to_size(img, [160, 192, 160],
                                            input_channels)
            img = normalise_image(img)

            img_list[train_test].append(img)
            pids_list[train_test].append(patient_id)
            xOffsest_list[train_test].append(offsets[0])
            yOffsest_list[train_test].append(offsets[1])
            zOffsest_list[train_test].append(offsets[2])

            write_buffer += 1

            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer
                _write_range_to_hdf5(data, train_test, img_list, pids_list,
                                     xOffsest_list, yOffsest_list,
                                     zOffsest_list, counter_from, counter_to)
                _release_tmp_memory(img_list, pids_list, xOffsest_list,
                                    yOffsest_list, zOffsest_list, train_test)

                # reset stuff for next iteration
                counter_from = counter_to
                write_buffer = 0

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        if len(file_list[train_test]) > 0:
            _write_range_to_hdf5(data, train_test, img_list, pids_list,
                                 xOffsest_list, yOffsest_list, zOffsest_list,
                                 counter_from, counter_to)

    # After test train loop:
    hdf5_file.close()
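_write_range_to_hdf5 and _release_tmp_memory are not part of this excerpt; the buffering pattern around MAX_WRITE_BUFFER amounts to slice-assigning the accumulated arrays into a preallocated HDF5 dataset and then clearing the Python lists. A minimal sketch of that idea, with illustrative names and sizes rather than the repository's actual helpers:

import numpy as np
import h5py

MAX_WRITE_BUFFER = 5

def write_range(dataset, buffer, counter_from, counter_to):
    # copy the buffered arrays into rows [counter_from, counter_to) of the dataset
    dataset[counter_from:counter_to, ...] = np.asarray(buffer, dtype=np.float32)

with h5py.File('example.hdf5', 'w') as f:
    images = f.create_dataset('images_train', [12, 32, 32], dtype=np.float32)

    buffer, counter_from = [], 0
    for i in range(12):
        buffer.append(np.random.rand(32, 32))

        if len(buffer) >= MAX_WRITE_BUFFER:
            counter_to = counter_from + len(buffer)
            write_range(images, buffer, counter_from, counter_to)
            buffer, counter_from = [], counter_to  # release memory, advance the write pointer

    # flush whatever is left after the loop
    if buffer:
        write_range(images, buffer, counter_from, counter_from + len(buffer))

Writing in chunks like this keeps only MAX_WRITE_BUFFER volumes or slices in memory at a time instead of the whole dataset.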
Example #27
def prepare_data(input_folder, preproc_folder, protocol, idx_start, idx_end):

    images = []
    affines = []
    patnames = []
    masks = []

    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # ========================
    # iterate through the requested indices
    # ========================
    for idx in range(idx_start, idx_end):

        logging.info(
            '============================================================')

        # ========================
        # get the file name for this subject
        # ========================
        filename = filenames[idx]

        # ========================
        # define how much of the image can be cropped out as it consists of zeros
        # ========================
        x_start = 18
        x_end = -18
        y_start = 28
        y_end = -27
        z_start = 2
        z_end = -34
        # original images are 260 * 311 * 260
        # cropping them down to 224 * 256 * 224

        # ========================
        # read the contents inside the top-level subject directory
        # ========================
        with zipfile.ZipFile(filename, 'r') as zfile:

            # ========================
            # search for the relevant files
            # ========================
            for name in zfile.namelist():

                # ========================
                # search for files inside the T1w directory
                # ========================
                if re.search(r'\/T1w/', name) != None:

                    # ========================
                    # search for .gz files inside the T1w directory
                    # ========================
                    if re.search(r'\.gz$', name) != None:

                        # ========================
                        # get the protocol image
                        # ========================
                        if re.search(protocol + 'acpc_dc_restore_brain',
                                     name) != None:

                            logging.info('reading image: %s' % name)

                            _filepath = zfile.extract(
                                name, sys_config.preproc_folder_hcp
                            )  # extract the image filepath

                            _patname = name[:name.find(
                                '/')]  # extract the patient name

                            _img_data, _img_affine, _img_header = utils.load_nii(
                                _filepath)  # read the 3d image

                            _img_data = _img_data[
                                x_start:x_end, y_start:y_end, z_start:
                                z_end]  # discard some pixels as they are always zero.

                            _img_data = utils.normalise_image(
                                _img_data, norm_type='div_by_max'
                            )  # normalise the image (volume wise)

                            savepath = sys_config.preproc_folder_hcp + _patname + '/preprocessed_image' + protocol + '.nii'  # save the pre-processed image
                            utils.save_nii(savepath, _img_data, _img_affine,
                                           _img_header)

                            images.append(
                                _img_data
                            )  # append to the list of all images, affines and patient names
                            affines.append(_img_affine)
                            patnames.append(_patname)

                        # ========================
                        # get the segmentation mask
                        # ========================
                        if re.search(
                                'aparc.aseg', name
                        ) != None:  # segmentation mask with ~100 classes

                            if re.search('T1wDividedByT2w_', name) == None:

                                logging.info('reading mask: %s' % name)

                                _segpath = zfile.extract(
                                    name, sys_config.preproc_folder_hcp
                                )  # extract the segmentation mask

                                _patname = name[:name.find(
                                    '/')]  # extract the patient name

                                _seg_data, _seg_affine, _seg_header = utils.load_nii(
                                    _segpath)  # read the segmentation mask

                                _seg_data = _seg_data[
                                    x_start:x_end, y_start:y_end, z_start:
                                    z_end]  # discard some pixels as they are always zero.

                                _seg_data = utils.group_segmentation_classes(
                                    _seg_data
                                )  # group the segmentation classes as required

                                savepath = sys_config.preproc_folder_hcp + _patname + '/preprocessed_gt15.nii'  # save the pre-processed segmentation ground truth
                                utils.save_nii(savepath, _seg_data,
                                               _seg_affine, _seg_header)

                                masks.append(
                                    _seg_data
                                )  # append to the list of all masks

    # ========================
    # convert the lists to arrays
    # ========================
    images = np.array(images)
    affines = np.array(affines)
    patnames = np.array(patnames)
    masks = np.array(masks, dtype='uint8')

    # ========================
    # merge along the y-axis to get a stack of x-z slices, for the images as well as the masks
    # ========================
    images = images.swapaxes(1, 2)
    images = images.reshape(-1, images.shape[2], images.shape[3])
    masks = masks.swapaxes(1, 2)
    masks = masks.reshape(-1, masks.shape[2], masks.shape[3])

    # ========================
    # save the processed images and masks so that they can be directly read the next time
    # make appropriate filenames according to the requested indices of training, validation and test images
    # ========================
    logging.info('Saving pre-processed files...')
    config_details = '%sfrom%dto%d_' % (protocol, idx_start, idx_end)
    filepath_images = preproc_folder + config_details + 'images_2d.npy'
    filepath_masks = preproc_folder + config_details + 'annotations15_2d.npy'
    filepath_affine = preproc_folder + config_details + 'affines.npy'
    filepath_patnames = preproc_folder + config_details + 'patnames.npy'
    np.save(filepath_images, images)
    np.save(filepath_masks, masks)
    np.save(filepath_affine, affines)
    np.save(filepath_patnames, patnames)

    return images, masks, affines, patnames
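The HCP reader above never unpacks the whole archive; it walks the zip listing and extracts only the files whose paths match the protocol and segmentation patterns. A minimal, runnable sketch of that zipfile plus regex pattern, using an in-memory archive with a made-up layout (the real code additionally calls zfile.extract(name, folder) to write the matched file to disk):

import io
import re
import zipfile

# build a tiny in-memory archive that mimics the expected layout (names are made up)
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w') as z:
    z.writestr('100307/T1w/T1w_acpc_dc_restore_brain.nii.gz', b'fake nifti bytes')
    z.writestr('100307/T1w/aparc.aseg.nii.gz', b'fake nifti bytes')
    z.writestr('100307/release-notes.txt', b'ignore me')

with zipfile.ZipFile(buf, 'r') as zfile:
    for name in zfile.namelist():
        # keep only gzipped files inside the T1w directory, as in the loop above
        if re.search(r'/T1w/', name) and re.search(r'\.gz$', name):
            subject_id = name[:name.find('/')]  # top-level folder is the subject id
            print(subject_id, name)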
Example #28
def prepare_data(input_folder,
                 preproc_folder, # bias corrected images will be saved here already
                 output_file,
                 size,
                 target_resolution,
                 cv_fold_num):

    # =======================
    # create the hdf5 file where everything will be written
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # =======================
    # read all the images and count the number of slices along the append axis (the one with the lowest resolution)
    # =======================
    logging.info('Counting files and parsing meta data...')    
    # using the bias corrected images in the preproc folder for this step
    num_slices, patient_ids_list = count_slices_and_patient_ids_list(preproc_folder,
                                                                     cv_fold_number = cv_fold_num)
        
    # =======================
    # set the number of slices according to what has been found from the previous function
    # =======================
    nx, ny = size
    n_test = num_slices['test']
    n_train = num_slices['train']
    n_val = num_slices['validation']

    # =======================
    # Create datasets for images and masks
    # =======================
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'], [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset("images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['masks_%s' % tt] = hdf5_file.create_dataset("masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    mask_list = {'test': [], 'train': [], 'validation': []}
    img_list = {'test': [], 'train': [], 'validation': []}
    nx_list = {'test': [], 'train': [], 'validation': []}
    ny_list = {'test': [], 'train': [], 'validation': []}
    nz_list = {'test': [], 'train': [], 'validation': []}
    px_list = {'test': [], 'train': [], 'validation': []}
    py_list = {'test': [], 'train': [], 'validation': []}
    pz_list = {'test': [], 'train': [], 'validation': []}
    pat_names_list = {'test': [], 'train': [], 'validation': []}              
                
    # =======================
    # read data of each subject, preprocess it and write to the hdf5 file
    # =======================
    logging.info('Parsing image files')
    for train_test in ['test', 'train', 'validation']:

        write_buffer = 0
        counter_from = 0
        patient_counter = 0
        
        for patient_id in patient_ids_list[train_test]:
            
            filepath_orig_mhd_format = input_folder + 'Case' + patient_id + '.mhd'
            filepath_orig_nii_format = preproc_folder + 'Case' + patient_id + '.nii.gz'
            filepath_bias_corrected_nii_format = preproc_folder + 'Case' + patient_id + '_n4.nii.gz'
            filepath_seg_nii_format = preproc_folder + 'Case' + patient_id + '_segmentation.nii.gz'

            patient_counter += 1
            pat_names_list[train_test].append('case' + patient_id)

            logging.info('================================')
            logging.info('Doing: %s' % filepath_orig_mhd_format)
            
            # ================================    
            # read the original mhd image, in order to extract pixel resolution information
            # ================================    
            img_mhd = sitk.ReadImage(filepath_orig_mhd_format)
            pixel_size = img_mhd.GetSpacing()
            px_list[train_test].append(float(pixel_size[0]))
            py_list[train_test].append(float(pixel_size[1]))
            pz_list[train_test].append(float(pixel_size[2]))

            # ================================    
            # read bias corrected image
            # ================================    
            img = utils.load_nii(filepath_bias_corrected_nii_format)[0]

            # ================================    
            # normalize the image
            # ================================    
            img = utils.normalise_image(img, norm_type='div_by_max')

            # ================================    
            # read the labels
            # ================================    
            mask = utils.load_nii(filepath_seg_nii_format)[0]            
            
            # ================================    
            # skimage io with the SimpleITK plugin was used to read the images in the convert_to_nii_and_correct_bias_field function.
            # this led to the arrays being read as z-x-y
            # move the axes appropriately, so that the resolution read above is correct for the corresponding axes.
            # ================================    
            img = np.swapaxes(np.swapaxes(img, 0, 1), 1, 2)
            mask = np.swapaxes(np.swapaxes(mask, 0, 1), 1, 2)
            
            # ================================    
            # write to the dimensions now
            # ================================    
            nx_list[train_test].append(mask.shape[0])
            ny_list[train_test].append(mask.shape[1])
            nz_list[train_test].append(mask.shape[2])

            print('mask.shape')
            print(mask.shape)
            print('img.shape')
            print(img.shape)
            
            ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
            scale_vector = [pixel_size[0] / target_resolution[0],
                            pixel_size[1] / target_resolution[1]]

            for zz in range(img.shape[2]):

                slice_img = np.squeeze(img[:, :, zz])
                slice_rescaled = transform.rescale(slice_img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode = 'constant')

                slice_mask = np.squeeze(mask[:, :, zz])
                mask_rescaled = transform.rescale(slice_mask,
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode='constant')

                slice_cropped = utils.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                mask_cropped = utils.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)

                img_list[train_test].append(slice_cropped)
                mask_list[train_test].append(mask_cropped)

                write_buffer += 1

                # Writing needs to happen inside the loop over the slices
                if write_buffer >= MAX_WRITE_BUFFER:

                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, img_list, mask_list, counter_from, counter_to)
                    _release_tmp_memory(img_list, mask_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0


        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, img_list, mask_list, counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, train_test)

    # Write the small datasets
    for tt in ['test', 'train', 'validation']:
        hdf5_file.create_dataset('nx_%s' % tt, data=np.asarray(nx_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('ny_%s' % tt, data=np.asarray(ny_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('nz_%s' % tt, data=np.asarray(nz_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('px_%s' % tt, data=np.asarray(px_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('py_%s' % tt, data=np.asarray(py_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('pz_%s' % tt, data=np.asarray(pz_list[tt], dtype=np.float32))
        hdf5_file.create_dataset('patnames_%s' % tt, data=np.asarray(pat_names_list[tt], dtype="S10"))
    
    # After test train loop:
    hdf5_file.close()
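utils.crop_or_pad_slice_to_size is called throughout these scripts but its implementation is not part of this excerpt; from its usage it centre-crops or zero-pads a 2D slice to a fixed in-plane size. A minimal sketch of one way such a helper could look, under that assumption and not the repository's actual code:

import numpy as np

def crop_or_pad_slice_to_size(slice_2d, nx, ny):
    # centre-crop or zero-pad a 2D array to shape (nx, ny)
    x, y = slice_2d.shape
    out = np.zeros((nx, ny), dtype=slice_2d.dtype)

    x_s, y_s = max((x - nx) // 2, 0), max((y - ny) // 2, 0)  # crop offsets in the input
    x_c, y_c = max((nx - x) // 2, 0), max((ny - y) // 2, 0)  # pad offsets in the output

    cropped = slice_2d[x_s:x_s + min(x, nx), y_s:y_s + min(y, ny)]
    out[x_c:x_c + cropped.shape[0], y_c:y_c + cropped.shape[1]] = cropped
    return out

print(crop_or_pad_slice_to_size(np.ones((300, 180)), 256, 256).shape)  # (256, 256)

The larger axis is cropped symmetrically and the smaller one is padded symmetrically with zeros, which matches how the evaluation code later reassembles predictions back into the original slice grid.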
def prepare_data(input_folder, output_file, mode, size, target_resolution):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    assert (mode in ['2D', '3D']), 'Unknown mode: %s' % mode
    if mode == '2D' and not len(size) == 2:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '3D' and not len(size) == 3:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '2D' and not len(target_resolution) == 2:
        raise AssertionError(
            'Inadequate number of target resolution parameters')
    if mode == '3D' and not len(target_resolution) == 3:
        raise AssertionError(
            'Inadequate number of target resolution parameters')

    hdf5_file = h5py.File(output_file, "w")

    diag_list = {'test': [], 'train': []}
    height_list = {'test': [], 'train': []}
    weight_list = {'test': [], 'train': []}
    patient_id_list = {'test': [], 'train': []}
    cardiac_phase_list = {'test': [], 'train': []}

    file_list = {'test': [], 'train': []}
    num_slices = {'test': 0, 'train': 0}

    logging.info('Counting files and parsing meta data...')

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                file_list[train_test].append(file)

                # diag_list[train_test].append(diagnosis_to_int(infos['Group']))
                diag_list[train_test].append(diagnosis_dict[infos['Group']])
                weight_list[train_test].append(infos['Weight'])
                height_list[train_test].append(infos['Height'])

                patient_id_list[train_test].append(patient_id)

                systole_frame = int(infos['ES'])
                diastole_frame = int(infos['ED'])

                file_base = file.split('.')[0]
                frame = int(file_base.split('frame')[-1])
                if frame == systole_frame:
                    cardiac_phase_list[train_test].append(1)  # 1 == systole
                elif frame == diastole_frame:
                    cardiac_phase_list[train_test].append(2)  # 2 == diastole
                else:
                    cardiac_phase_list[train_test].append(
                        0)  # 0 means other phase

                nifty_img = nib.load(file)
                num_slices[train_test] += nifty_img.shape[2]

    # Write the small datasets
    for tt in ['test', 'train']:
        hdf5_file.create_dataset('diagnosis_%s' % tt,
                                 data=np.asarray(diag_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('weight_%s' % tt,
                                 data=np.asarray(weight_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('height_%s' % tt,
                                 data=np.asarray(height_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('patient_id_%s' % tt,
                                 data=np.asarray(patient_id_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('cardiac_phase_%s' % tt,
                                 data=np.asarray(cardiac_phase_list[tt],
                                                 dtype=np.uint8))

    if mode == '3D':
        nx, ny, nz_max = size
        n_train = len(file_list['train'])
        n_test = len(file_list['test'])

    elif mode == '2D':
        nx, ny = size
        n_test = num_slices['test']
        n_train = num_slices['train']

    else:
        raise AssertionError('Wrong mode setting. This should never happen.')

    # Create datasets for images and masks
    data = {}

    for tt, num_points in zip(['test', 'train'], [n_test, n_train]):
        data['images_%s' % tt] = hdf5_file.create_dataset(
            "images_%s" % tt, [num_points] + list(size), dtype=np.float32)
        data['masks_%s' % tt] = hdf5_file.create_dataset(
            "masks_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    mask_list = {'test': [], 'train': []}
    img_list = {'test': [], 'train': []}

    logging.info('Parsing image files')

    for train_test in ['test', 'train']:

        write_buffer = 0
        counter_from = 0

        for file in file_list[train_test]:

            logging.info(
                '-----------------------------------------------------------')
            logging.info('Doing: %s' % file)

            file_base = file.split('.nii.gz')[0]
            file_mask = file_base + '_gt.nii.gz'

            img_dat = utils.load_nii(file)
            mask_dat = utils.load_nii(file_mask)

            img = img_dat[0].copy()
            mask = mask_dat[0].copy()

            img = image_utils.normalise_image(img)

            pixel_size = (img_dat[2].structarr['pixdim'][1],
                          img_dat[2].structarr['pixdim'][2],
                          img_dat[2].structarr['pixdim'][3])

            logging.info('Pixel size:')
            logging.info(pixel_size)

            ### PROCESSING LOOP FOR 3D DATA ################################
            if mode == '3D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1],
                    pixel_size[2] / target_resolution[2]
                ]

                img_scaled = transform.rescale(img,
                                               scale_vector,
                                               order=1,
                                               preserve_range=True,
                                               multichannel=False,
                                               mode='constant')
                mask_scaled = transform.rescale(mask,
                                                scale_vector,
                                                order=0,
                                                preserve_range=True,
                                                multichannel=False,
                                                mode='constant')

                slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)
                mask_vol = np.zeros((nx, ny, nz_max), dtype=np.uint8)

                nz_curr = img_scaled.shape[2]
                stack_from = (nz_max - nz_curr) // 2

                if stack_from < 0:
                    raise AssertionError(
                        'nz_max is too small for the chosen through-plane resolution. Consider changing '
                        'the size or the target resolution in the through-plane.'
                    )

                for zz in range(nz_curr):

                    slice_rescaled = img_scaled[:, :, zz]
                    mask_rescaled = mask_scaled[:, :, zz]

                    slice_cropped = crop_or_pad_slice_to_size(
                        slice_rescaled, nx, ny)
                    mask_cropped = crop_or_pad_slice_to_size(
                        mask_rescaled, nx, ny)

                    slice_vol[:, :, stack_from] = slice_cropped
                    mask_vol[:, :, stack_from] = mask_cropped

                    stack_from += 1

                img_list[train_test].append(slice_vol)
                mask_list[train_test].append(mask_vol)

                write_buffer += 1

                if write_buffer >= MAX_WRITE_BUFFER:

                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, img_list, mask_list,
                                         counter_from, counter_to)
                    _release_tmp_memory(img_list, mask_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0

            ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
            elif mode == '2D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1]
                ]

                for zz in range(img.shape[2]):

                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       mode='constant')

                    slice_mask = np.squeeze(mask[:, :, zz])
                    mask_rescaled = transform.rescale(slice_mask,
                                                      scale_vector,
                                                      order=0,
                                                      preserve_range=True,
                                                      multichannel=False,
                                                      mode='constant')

                    slice_cropped = crop_or_pad_slice_to_size(
                        slice_rescaled, nx, ny)
                    mask_cropped = crop_or_pad_slice_to_size(
                        mask_rescaled, nx, ny)

                    img_list[train_test].append(slice_cropped)
                    mask_list[train_test].append(mask_cropped)

                    write_buffer += 1

                    # Writing needs to happen inside the loop over the slices
                    if write_buffer >= MAX_WRITE_BUFFER:

                        counter_to = counter_from + write_buffer
                        _write_range_to_hdf5(data, train_test, img_list,
                                             mask_list, counter_from,
                                             counter_to)
                        _release_tmp_memory(img_list, mask_list, train_test)

                        # reset stuff for next iteration
                        counter_from = counter_to
                        write_buffer = 0

        # after file loop: Write the remaining data

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, img_list, mask_list,
                             counter_from, counter_to)
        _release_tmp_memory(img_list, mask_list, train_test)

    # After test train loop:
    hdf5_file.close()
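Both the ACDC preparation above and the evaluation function below read each patient's Info.cfg as plain 'key: value' lines, and use the ED/ES entries to tag a frame as diastole, systole or other. A small sketch of that parsing and phase assignment, with made-up file contents:

# hypothetical contents of an ACDC Info.cfg file
info_text = """ED: 1
ES: 12
Group: NOR
Height: 175.0
Weight: 70.0
"""

infos = {}
for line in info_text.splitlines():
    label, value = line.split(':')
    infos[label] = value.strip()

systole_frame = int(infos['ES'])
diastole_frame = int(infos['ED'])

def phase_code(frame):
    # 1 == systole, 2 == diastole, 0 == any other phase (same convention as above)
    if frame == systole_frame:
        return 1
    if frame == diastole_frame:
        return 2
    return 0

print(phase_code(12), phase_code(1), phase_code(5))  # 1 2 0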
def main(input_folder,
         output_folder,
         model_path,
         exp_config,
         do_postprocessing=False,
         gt_exists=True):

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make and restore vagan model
    segmenter_model = segmenter(
        exp_config=exp_config, data=data,
        fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    total_time = 0
    total_volumes = 0

    dice_list = []
    assd_list = []
    hd_list = []

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            if not int(patient_id) % 5 == 0:
                continue

            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img, img_affine, img_header = utils.load_nii(file)
                img = utils.normalise_image(img)
                zooms = img_header.get_zooms()

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask, mask_affine, mask_header = utils.load_nii(file_mask)

                start_time = time.time()

                if exp_config.dimensionality_mode == '2D':

                    pixel_size = (img_header.structarr['pixdim'][1],
                                  img_header.structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    nx, ny = exp_config.image_size

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape

                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx,
                                                           y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c +
                                              x, :] = slice_rescaled[:,
                                                                     y_s:y_s +
                                                                     ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c +
                                              y] = slice_rescaled[x_s:x_s +
                                                                  nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c +
                                              y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (1, 1, 1, 1)))
                        mask_out, softmax = segmenter_model.predict(
                            network_input)

                        prediction_cropped = np.squeeze(softmax[0, ...])

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros(
                            (x, y, exp_config.nlabels))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s +
                                              ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s +
                                                  ny, :] = prediction_cropped[
                                                      x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[
                                    x_s:x_s +
                                    nx, :, :] = prediction_cropped[:, y_c:y_c +
                                                                   y, :]
                            else:
                                slice_predictions[:, :, :] = prediction_cropped[
                                    x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1],
                                 exp_config.nlabels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:  # Without the GT mask, rescale back by the inverse scale factors.
                            # This can occasionally lead to a wrong volume size, which is why the
                            # GT mask shape is used for resizing whenever gt_exists.
                            prediction = transform.rescale(
                                slice_predictions, (1.0 / scale_vector[0],
                                                    1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        # import matplotlib.pyplot as plt
                        # fig = plt.Figure()
                        # for ii in range(3):
                        #     plt.subplot(1, 3, ii + 1)
                        #     plt.imshow(np.squeeze(prediction))
                        # plt.show()

                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.dimensionality_mode == '3D':

                    nx, ny, nz = exp_config.image_size

                    pixel_size = (img_header.structarr['pixdim'][1],
                                  img_header.structarr['pixdim'][2],
                                  img_header.structarr['pixdim'][3])

                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1],
                                    pixel_size[2] /
                                    exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    nz_curr = vol_scaled.shape[2]
                    stack_from = (nz_max - nz_curr) // 2
                    stack_counter = stack_from

                    x, y, z = vol_scaled.shape

                    x_s = (x - nx) // 2
                    y_s = (y - ny) // 2
                    x_c = (nx - x) // 2
                    y_c = (ny - y) // 2

                    for zz in range(nz_curr):

                        slice_rescaled = vol_scaled[:, :, zz]

                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx,
                                                           y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c +
                                              x, :] = slice_rescaled[:,
                                                                     y_s:y_s +
                                                                     ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c +
                                              y] = slice_rescaled[x_s:x_s +
                                                                  nx, :]

                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c +
                                              y] = slice_rescaled[:, :]

                        slice_vol[:, :, stack_counter] = slice_cropped
                        stack_counter += 1

                    stack_to = stack_counter

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))
                    start_time = time.time()
                    mask_out, softmax = segmenter_model.predict(network_input)
                    logging.info('Classified 3D: %f secs' %
                                 (time.time() - start_time))

                    prediction_nzs = mask_out[0, :, :, stack_from:
                                              stack_to]  # non-zero-slices

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES
                    prediction_scaled = np.zeros(
                        vol_scaled.shape)  # same spatial shape as the rescaled volume

                    # insert cropped region into original image again
                    if x > nx and y > ny:
                        prediction_scaled[x_s:x_s + nx,
                                          y_s:y_s + ny, :] = prediction_nzs
                    else:
                        if x <= nx and y > ny:
                            prediction_scaled[:, y_s:y_s +
                                              ny, :] = prediction_nzs[x_c:x_c +
                                                                      x, :, :]
                        elif x > nx and y <= ny:
                            prediction_scaled[
                                x_s:x_s +
                                nx, :, :] = prediction_nzs[:, y_c:y_c + y, :]
                        else:
                            prediction_scaled[:, :, :] = prediction_nzs[
                                x_c:x_c + x, y_c:y_c + y, :]

                    logging.info('Prediction_scaled mean %f' %
                                 (np.mean(prediction_scaled)))

                    prediction = transform.resize(
                        prediction_scaled,
                        (mask.shape[0], mask.shape[1], mask.shape[2], 1),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' %
                             elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame does not correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # Save predicted mask
                out_file_name = os.path.join(
                    output_folder, 'prediction',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_affine
                    out_header = mask_header
                else:
                    out_affine = img_affine
                    out_header = img_header

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine,
                               out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(
                    output_folder, 'image',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img, out_affine, out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(
                        output_folder, 'ground_truth',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(
                        np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask,
                                                 dtype=np.uint8)
                    diff_file_name = os.path.join(
                        output_folder, 'difference',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine,
                                   out_header)

                # calculate metrics
                y_ = prediction_arr
                y = mask

                per_lbl_dice = []
                per_lbl_assd = []
                per_lbl_hd = []

                for lbl in [3, 1, 2]:  #range(exp_config.nlabels):

                    binary_pred = (y_ == lbl) * 1
                    binary_gt = (y == lbl) * 1

                    if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0:
                        per_lbl_dice.append(1)
                        per_lbl_assd.append(0)
                        per_lbl_hd.append(0)
                    elif np.sum(binary_pred) > 0 and np.sum(
                            binary_gt) == 0 or np.sum(
                                binary_pred) == 0 and np.sum(binary_gt) > 0:
                        logging.warning(
                            'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
                        )
                        per_lbl_dice.append(0)
                        per_lbl_assd.append(1)
                        per_lbl_hd.append(1)
                    else:
                        per_lbl_dice.append(dc(binary_pred, binary_gt))
                        per_lbl_assd.append(
                            assd(binary_pred, binary_gt, voxelspacing=zooms))
                        per_lbl_hd.append(
                            hd(binary_pred, binary_gt, voxelspacing=zooms))

                dice_list.append(per_lbl_dice)
                assd_list.append(per_lbl_assd)
                hd_list.append(per_lbl_hd)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    dice_arr = np.asarray(dice_list)
    assd_arr = np.asarray(assd_list)
    hd_arr = np.asarray(hd_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)
    mean_per_lbl_assd = assd_arr.mean(axis=0)
    mean_per_lbl_hd = hd_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    logging.info('ASSD')
    logging.info(mean_per_lbl_assd)
    logging.info(np.mean(mean_per_lbl_assd))
    logging.info('HD')
    logging.info(mean_per_lbl_hd)
    logging.info(np.mean(mean_per_lbl_hd))
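The per-label metrics above come from medpy's dc, assd and hd. For reference, the Dice score for a single label can also be written directly in NumPy; a minimal sketch that follows the same convention as above of scoring 1 when the structure is absent from both volumes (this is not the medpy implementation):

import numpy as np

def dice_coefficient(pred, gt, label):
    # binarise both volumes for the requested label and compare overlaps
    binary_pred = (pred == label).astype(np.uint8)
    binary_gt = (gt == label).astype(np.uint8)
    intersection = np.sum(binary_pred * binary_gt)
    denominator = np.sum(binary_pred) + np.sum(binary_gt)
    if denominator == 0:
        return 1.0  # structure absent from both volumes -> perfect score, as above
    return 2.0 * intersection / denominator

# toy example: each array contains two voxels of label 1, one of which overlaps
pred = np.array([1, 1, 0, 2])
gt = np.array([1, 0, 1, 2])
print(dice_coefficient(pred, gt, 1))  # 2*1 / (2+2) = 0.5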