Example #1
def load_without_size_preprocessing(input_folder, site_name, idx, depth):

    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + site_name + '/*/'))

    # ==================
    # get file paths
    # ==================
    patient_name, image_path, label_path = get_image_and_label_paths(
        filenames[idx])

    # ============
    # read the image and normalize it to be between 0 and 1
    # ============
    image, _, image_hdr = utils.load_nii(image_path)
    image = np.swapaxes(
        image, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets

    # ==================
    # read the label file
    # ==================
    label, _, _ = utils.load_nii(label_path)
    label = np.swapaxes(
        label, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    label = utils.group_segmentation_classes(
        label)  # group the segmentation classes as required

    # ============
    # create a segmentation mask and use it to get rid of the skull in the image
    # ============
    label_mask = np.copy(label)
    label_mask[label > 0] = 1
    image = image * label_mask

    # ==================
    # crop out portions of the image that are all zeros (rough registration via visual inspection)
    # ==================
    if site_name == 'CALTECH':
        image = image[:, 80:, :]
        label = label[:, 80:, :]
    elif site_name == 'STANFORD':
        image, label = center_image_and_label(image, label)

    # ==================
    # crop volume along z axis (as there are several zeros towards the ends)
    # ==================
    image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
    label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

    # ==================
    # normalize the image
    # ==================
    image = utils.normalise_image(image, norm_type='div_by_max')

    return image, label
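A minimal usage sketch for the loader above; the dataset root, site name and depth below are illustrative assumptions, not values taken from the original project:

# Hypothetical call (paths and values are made up for illustration):
image, label = load_without_size_preprocessing(
    input_folder='/data/abide/',  # assumed dataset root with one sub-folder per site
    site_name='CALTECH',          # one of the two sites handled above
    idx=0,                        # first subject in the sorted folder list
    depth=256)                    # crop/pad the volume to 256 slices along z
print(image.shape, label.shape)   # image and label come back with the same shape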
Example #2
def load_without_size_preprocessing(input_folder, idx, protocol,
                                    preprocessing_folder, depth):

    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))

    # ==================
    # get file paths
    # ==================
    patient_name, image_path, label_path = get_image_and_label_paths(
        filenames[idx], protocol, preprocessing_folder)

    # ============
    # read the image and normalize it to be between 0 and 1
    # ============
    image, _, image_hdr = utils.load_nii(image_path)
    image = np.swapaxes(
        image, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    image = utils.normalise_image(image, norm_type='div_by_max')

    # ==================
    # read the label file
    # ==================
    label, _, _ = utils.load_nii(label_path)
    label = np.swapaxes(
        label, 1, 2
    )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
    label = utils.group_segmentation_classes(
        label)  # group the segmentation classes as required

    # ==================
    # crop volume along z axis (as there are several zeros towards the ends)
    # ==================
    image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
    label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

    return image, label
def prepare_data(input_folder, output_file, idx_start, idx_end, protocol, size,
                 target_resolution, preprocessing_folder):
    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # =======================
    # create a hdf5 file
    # =======================
    # hdf5_file = h5py.File(output_file, "w")
    #
    # # ===============================
    # # Create datasets for images and labels
    # # ===============================
    # data = {}
    # num_subjects = idx_end - idx_start
    #
    # data['images'] = hdf5_file.create_dataset("images", [num_subjects] + list(size), dtype=np.float32)
    # data['labels'] = hdf5_file.create_dataset("labels", [num_subjects] + list(size), dtype=np.uint8)
    #
    # ===============================
    # initialize lists
    # ===============================
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []

    # ===============================
    # initiate counter
    # ===============================
    patient_counter = 0

    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):
        logging.info('Volume {} of {}...'.format(idx, idx_end))

        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(
            filenames[idx], protocol, preprocessing_folder)

        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)

        # ==================
        # read the label file
        # ==================
        label, _, _ = utils.load_nii(label_path)
        label = utils.group_segmentation_classes(
            label)  # group the segmentation classes as required

        # # ==================
        # # collect some header info.
        # # ==================
        # px_list.append(float(image_hdr.get_zooms()[0]))
        # py_list.append(float(image_hdr.get_zooms()[1]))
        # pz_list.append(float(image_hdr.get_zooms()[2]))
        # nx_list.append(image.shape[0])
        # ny_list.append(image.shape[1])
        # nz_list.append(image.shape[2])
        # pat_names_list.append(patient_name)

        # ==================
        # crop volume along all axes from the ends (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_x(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_x(label, 256)
        image = utils.crop_or_pad_volume_to_size_along_y(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_y(label, 256)
        image = utils.crop_or_pad_volume_to_size_along_z(image, 256)
        label = utils.crop_or_pad_volume_to_size_along_z(label, 256)

        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')

        # ======================================================
        # rescale, crop / pad to make all images of the required size and resolution
        # ======================================================
        scale_vector = [
            image_hdr.get_zooms()[0] / target_resolution[0],
            image_hdr.get_zooms()[1] / target_resolution[1],
            image_hdr.get_zooms()[2] / target_resolution[2]
        ]

        image_rescaled = transform.rescale(image_normalized,
                                           scale_vector,
                                           order=1,
                                           preserve_range=True,
                                           multichannel=False,
                                           mode='constant')

        # label_onehot = utils.make_onehot(label, nlabels=15)
        #
        # label_onehot_rescaled = transform.rescale(label_onehot,
        #                                           scale_vector,
        #                                           order=1,
        #                                           preserve_range=True,
        #                                           multichannel=True,
        #                                           mode='constant')
        #
        # label_rescaled = np.argmax(label_onehot_rescaled, axis=-1)
        #
        # # ============
        # # the images and labels have been rescaled to the desired resolution.
        # # write them to the hdf5 file now.
        # # ============
        # image_list.append(image_rescaled)
        # label_list.append(label_rescaled)

        # ============
        # write to file
        # ============
        # image_rescaled
        volume_dir = os.path.join(preprocessing_folder,
                                  'volume_{:06d}'.format(idx))
        os.makedirs(volume_dir, exist_ok=True)
        for i in range(size[1]):
            slice_path = os.path.join(volume_dir,
                                      'slice_{:06d}.jpeg'.format(i))
            slice_2d = image_rescaled[:, i, :] * 255
            slice_img = Image.fromarray(slice_2d.astype(np.uint8))
            slice_img.save(slice_path)
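The helper utils.crop_or_pad_volume_to_size_along_z is called throughout these examples but not listed; a rough sketch of what such a function might look like, assuming a symmetric centre crop / zero pad (a guess, not the project's actual implementation):

import numpy as np

def crop_or_pad_volume_to_size_along_z(vol, target_z):
    # Centre-crop or zero-pad a 3D volume along its last (z) axis to target_z slices.
    nz = vol.shape[2]
    if nz >= target_z:                        # crop symmetrically around the centre
        start = (nz - target_z) // 2
        return vol[:, :, start:start + target_z]
    out = np.zeros(vol.shape[:2] + (target_z,), dtype=vol.dtype)
    start = (target_z - nz) // 2              # pad with zeros on both sides
    out[:, :, start:start + nz] = vol
    return out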
def main():

    # ===================================
    # read the test images
    # ===================================
    test_dataset_name = exp_config.test_dataset

    if test_dataset_name == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)

        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70

        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)

        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70

        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')

        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36

        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')

        image_depth = exp_config.image_depth_stanford
        idx_start = 16
        idx_end = 36

        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    imts = data_brain_test['images']
    name_test_subjects = data_brain_test['patnames']
    num_test_subjects = imts.shape[0] // image_depth
    ids = np.arange(idx_start, idx_end)

    orig_data_res_x = data_brain_test['px'][:]
    orig_data_res_y = data_brain_test['py'][:]
    orig_data_res_z = data_brain_test['pz'][:]
    orig_data_siz_x = data_brain_test['nx'][:]
    orig_data_siz_y = data_brain_test['ny'][:]
    orig_data_siz_z = data_brain_test['nz'][:]

    # ================================
    # set the log directory
    # ================================
    if exp_config.normalize:
        log_dir = os.path.join(sys_config.log_root,
                               exp_config.expname_normalizer)
    else:
        log_dir = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l

    if exp_config.post_process:
        file_suffix = '_with_post_process_with_dae_runs' + str(
            exp_config.dae_post_process_runs)
    else:
        file_suffix = ''

    logging.info(log_dir)

    # ================================
    # open a text file for writing the mean dice scores for each subject that is evaluated
    # ================================
    results_file = open(
        log_dir + '/' + test_dataset_name + '_' + 'test' + file_suffix +
        '.txt', "w")
    results_file.write("================================== \n")
    results_file.write("Test results \n")

    # ================================================================
    # For each test image, load the best model and compute the dice with this model
    # ================================================================
    dice_per_label_per_subject = []
    hsd_per_label_per_subject = []

    for sub_num in range(num_test_subjects):

        subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
        subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num + 1])
        image = imts[subject_id_start_slice:subject_id_end_slice, :, :]

        # ==================================================================
        # setup logging
        # ==================================================================
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s %(message)s')
        subject_name = str(name_test_subjects[sub_num])[2:-1]
        logging.info(
            '============================================================')
        logging.info('Subject id: %s' % sub_num)

        # ==================================================================
        # predict segmentation at the pre-processed resolution
        # ==================================================================
        predicted_labels, normalized_image = predict_segmentation(
            subject_name, image, exp_config.normalize, exp_config.post_process)

        # ==================================================================
        # read the original segmentation mask
        # ==================================================================
        if test_dataset_name == 'HCPT1':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T1',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'HCPT2':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T2',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'CALTECH':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='CALTECH',
                idx=ids[sub_num],
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'STANFORD':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='STANFORD',
                idx=ids[sub_num],
                depth=image_depth)
            num_rotations = 0

        # ==================================================================
        # convert the predictions back to original resolution
        # ==================================================================
        predicted_labels_orig_res_and_size = rescale_and_crop(
            predicted_labels,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=0,
            num_rotations=num_rotations)

        normalized_image_orig_res_and_size = rescale_and_crop(
            normalized_image,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=1,
            num_rotations=num_rotations)

        # ==================================================================
        # compute dice at the original resolution
        # ==================================================================
        dice_per_label_this_subject = met.f1_score(
            labels_orig.flatten(),
            predicted_labels_orig_res_and_size.flatten(),
            average=None)

        # ==================================================================
        # compute Hausdorff distance at the original resolution
        # ==================================================================
        hsd_per_label_this_subject = utils.compute_surface_distance(
            y1=labels_orig,
            y2=predicted_labels_orig_res_and_size,
            nlabels=exp_config.nlabels)

        # ================================================================
        # save sample results
        # ================================================================
        utils_vis.save_sample_prediction_results(
            x=utils.crop_or_pad_volume_to_size_along_z(image_orig, 256),
            x_norm=utils.crop_or_pad_volume_to_size_along_z(
                normalized_image_orig_res_and_size, 256),
            y_pred=utils.crop_or_pad_volume_to_size_along_z(
                predicted_labels_orig_res_and_size, 256),
            gt=utils.crop_or_pad_volume_to_size_along_z(labels_orig, 256),
            num_rotations=-num_rotations,  # rotate for consistent visualization across datasets
            savepath=log_dir + '/' + test_dataset_name + '_' + 'test' + '_' +
            subject_name + file_suffix + '.png')

        # ================================
        # write the mean fg dice of this subject to the text file
        # ================================
        results_file.write(subject_name +
                           ":: dice (mean, std over all FG labels): ")
        results_file.write(
            str(np.round(np.mean(dice_per_label_this_subject[1:]), 3)) + ", " +
            str(np.round(np.std(dice_per_label_this_subject[1:]), 3)))
        dice_per_label_per_subject.append(dice_per_label_this_subject)

        results_file.write(
            ", hausdorff distance (mean, std over all FG labels): ")
        results_file.write(
            str(np.round(np.mean(hsd_per_label_this_subject), 3)) + ", " +
            str(np.round(np.std(hsd_per_label_this_subject), 3)))
        hsd_per_label_per_subject.append(hsd_per_label_this_subject)

        results_file.write("\n")

    # ================================================================
    # write per label statistics over all subjects
    # ================================================================
    dice_per_label_per_subject = np.array(dice_per_label_per_subject)
    hsd_per_label_per_subject = np.array(hsd_per_label_per_subject)

    # ================================
    # In the array images_dice, in the rows, there are subjects
    # and in the columns, there are the dice scores for each label for a particular subject
    # ================================
    results_file.write("================================== \n")
    results_file.write("Label: dice mean, std. deviation over all subjects\n")
    for i in range(dice_per_label_per_subject.shape[1]):
        results_file.write(
            str(i) + ": " +
            str(np.round(np.mean(dice_per_label_per_subject[:, i]), 3)) +
            ", " + str(np.round(np.std(dice_per_label_per_subject[:, i]), 3)) +
            "\n")

    results_file.write("================================== \n")
    results_file.write(
        "Label: hausdorff distance mean, std. deviation over all subjects\n")
    for i in range(hsd_per_label_per_subject.shape[1]):
        results_file.write(
            str(i + 1) + ": " +
            str(np.round(np.mean(hsd_per_label_per_subject[:, i]), 3)) + ", " +
            str(np.round(np.std(hsd_per_label_per_subject[:, i]), 3)) + "\n")

    # ==================
    # write the mean dice over all subjects and all labels
    # ==================
    results_file.write("================================== \n")
    results_file.write(
        "DICE Mean, std. deviation over foreground labels over all subjects: "
        + str(np.round(np.mean(dice_per_label_per_subject[:, 1:]), 3)) + ", " +
        str(np.round(np.std(dice_per_label_per_subject[:, 1:]), 3)) + "\n")
    results_file.write(
        "HSD Mean, std. deviation over labels over all subjects: " +
        str(np.round(np.mean(hsd_per_label_per_subject), 3)) + ", " +
        str(np.round(np.std(hsd_per_label_per_subject), 3)) + "\n")
    results_file.write("================================== \n")
    results_file.close()
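For reference, met.f1_score with average=None returns one score per label, which is why index 0 (the background label) is dropped when averaging foreground Dice above; a small self-contained illustration with made-up label arrays:

import numpy as np
from sklearn import metrics as met

y_true = np.array([0, 0, 1, 1, 2, 2])   # toy ground-truth labels
y_pred = np.array([0, 1, 1, 1, 2, 0])   # toy predictions
dice_per_label = met.f1_score(y_true, y_pred, average=None)  # one F1 (Dice) score per label
print(dice_per_label)                    # [0.5, 0.8, 0.667] for labels 0, 1, 2
print(np.mean(dice_per_label[1:]))       # mean foreground Dice, as written to the results file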
Example #5
def prepare_data(input_folder, output_file, idx_start, idx_end, protocol, size,
                 depth, target_resolution, preprocessing_folder):

    # ========================
    # read the filenames
    # ========================
    folders_list = sorted(glob.glob(input_folder + '/*/'))
    logging.info('Number of images in the dataset: %s' %
                 str(len(folders_list)))

    # =======================
    # create a hdf5 file
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_slices = count_slices(folders_list, idx_start, idx_end, depth)

    data['images'] = hdf5_file.create_dataset("images",
                                              [num_slices] + list(size),
                                              dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels",
                                              [num_slices] + list(size),
                                              dtype=np.uint8)

    # ===============================
    # initialize lists
    # ===============================
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []

    # ===============================
    # ===============================
    write_buffer = 0
    counter_from = 0

    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):

        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(
            folders_list[idx])

        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)
        image = np.swapaxes(
            image, 1, 2
        )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets

        # ==================
        # read the label file
        # ==================
        label, _, _ = utils.load_nii(label_path)
        label = np.swapaxes(
            label, 1, 2
        )  # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        # labels have already been grouped as required

        # ============
        # create a segmentation mask and use it to get rid of the skull in the image
        # ============
        label_mask = np.copy(label)
        label_mask[label > 0] = 1
        image = image * label_mask

        # ==================
        # crop out portions of the image that are all zeros (rough registration via visual inspection)
        # ==================
        image, label = center_image_and_label(image, label)

        # plt.figure(); plt.imshow(image[:,:,50], cmap='gray'); plt.title(patient_name); plt.show(); plt.close()

        # ==================
        # crop volume along z axis (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
        label = utils.crop_or_pad_volume_to_size_along_z(label, depth)

        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(
            float(image_hdr.get_zooms()[2])
        )  # since axes 1 and 2 have been swapped. this is important when dealing with pixel dimensions
        pz_list.append(float(image_hdr.get_zooms()[1]))
        nx_list.append(image.shape[0])
        ny_list.append(
            image.shape[1]
        )  # since axes 1 and 2 have been swapped. however, only the final axis locations are relevant when dealing with shapes
        nz_list.append(image.shape[2])
        pat_names_list.append(patient_name)

        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')

        # ======================================================
        ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
        # ======================================================
        scale_vector = [
            image_hdr.get_zooms()[0] / target_resolution[0],
            image_hdr.get_zooms()[2] / target_resolution[1]
        ]  # since axes 1 and 2 have been swapped. this is important when dealing with pixel dimensions

        for zz in range(image.shape[2]):

            # ============
            # rescale the images and labels so that their orientation matches that of the nci dataset
            # ============
            image2d_rescaled = rescale(np.squeeze(image_normalized[:, :, zz]),
                                       scale_vector,
                                       order=1,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            label2d_rescaled = rescale(np.squeeze(label[:, :, zz]),
                                       scale_vector,
                                       order=0,
                                       preserve_range=True,
                                       multichannel=False,
                                       mode='constant')

            # ============
            # rotate to align with other datasets
            # ============
            image2d_rescaled_rotated = np.rot90(image2d_rescaled, k=0)
            label2d_rescaled_rotated = np.rot90(label2d_rescaled, k=0)

            # ============
            # crop or pad to make of the same size
            # ============
            image2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(
                image2d_rescaled_rotated, size[0], size[1])
            label2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(
                label2d_rescaled_rotated, size[0], size[1])

            # ============
            # append to list
            # ============
            image_list.append(image2d_rescaled_rotated_cropped)
            label_list.append(label2d_rescaled_rotated_cropped)

            write_buffer += 1

            # Writing needs to happen inside the loop over the slices
            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer

                _write_range_to_hdf5(data, image_list, label_list,
                                     counter_from, counter_to)

                _release_tmp_memory(image_list, label_list)

                # update counters
                counter_from = counter_to
                write_buffer = 0

    logging.info('Writing remaining data')
    counter_to = counter_from + write_buffer
    _write_range_to_hdf5(data, image_list, label_list, counter_from,
                         counter_to)
    _release_tmp_memory(image_list, label_list)

    # Write the small datasets
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames',
                             data=np.asarray(pat_names_list, dtype="S10"))

    # After test train loop:
    hdf5_file.close()
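The buffering helpers _write_range_to_hdf5 and _release_tmp_memory are used above but not shown; a plausible sketch matching the call signature in this example (an assumption about the original code, not a verbatim copy):

import gc
import numpy as np

def _write_range_to_hdf5(data, img_list, lab_list, counter_from, counter_to):
    # Write the buffered slices into the pre-allocated hdf5 datasets at rows [counter_from, counter_to).
    data['images'][counter_from:counter_to, ...] = np.asarray(img_list, dtype=np.float32)
    data['labels'][counter_from:counter_to, ...] = np.asarray(lab_list, dtype=np.uint8)

def _release_tmp_memory(img_list, lab_list):
    # Empty the in-memory buffers so the next chunk starts from clean lists.
    img_list.clear()
    lab_list.clear()
    gc.collect()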
Example #6
def prepare_data(
        input_folder,
        output_file,
        mode,
        size,  # for 3d: (nz, nx, ny), for 2d: (nx, ny)
        target_resolution,  # for 3d: (px, py, pz), for 2d: (px, py)
        cv_fold_num):
    '''
    Main function that prepares a dataset from the raw challenge data to an hdf5 dataset
    '''

    assert (mode in ['2D', '3D']), 'Unknown mode: %s' % mode
    if mode == '2D' and not len(size) == 2:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '3D' and not len(size) == 3:
        raise AssertionError('Inadequate number of size parameters')
    if mode == '2D' and not len(target_resolution) == 2:
        raise AssertionError(
            'Inadequate number of target resolution parameters')
    if mode == '3D' and not len(target_resolution) == 3:
        raise AssertionError(
            'Inadequate number of target resolution parameters')

    # ============
    # create an empty hdf5 file
    # ============
    hdf5_file = h5py.File(output_file, "w")

    # ============
    # create empty lists for filling header info
    # ============
    diag_list = {'test': [], 'train': [], 'validation': []}
    height_list = {'test': [], 'train': [], 'validation': []}
    weight_list = {'test': [], 'train': [], 'validation': []}
    patient_id_list = {'test': [], 'train': [], 'validation': []}
    cardiac_phase_list = {'test': [], 'train': [], 'validation': []}
    nx_list = {'test': [], 'train': [], 'validation': []}
    ny_list = {'test': [], 'train': [], 'validation': []}
    nz_list = {'test': [], 'train': [], 'validation': []}
    px_list = {'test': [], 'train': [], 'validation': []}
    py_list = {'test': [], 'train': [], 'validation': []}
    pz_list = {'test': [], 'train': [], 'validation': []}

    file_list = {'test': [], 'train': [], 'validation': []}
    num_slices = {'test': 0, 'train': 0, 'validation': 0}

    # ============
    # go through all images and get header info.
    # one round of parsing is done just to get all the header info. The size info is used to create empty fields for the images and labels, with the required sizes.
    # Then, another round of reading the images and labels is done, which are pre-processed and written into the hdf5 file
    # ============
    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            # ============
            # train_test_validation split
            # ============
            train_test = test_train_val_split(patient_id=int(folder[-3:]),
                                              cv_fold_number=cv_fold_num)

            infos = {}
            for line in open(os.path.join(folder_path, 'Info.cfg')):
                label, value = line.split(':')
                infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')

            # ============
            # reading this patient's image and collecting header info
            # ============
            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??_n4.nii.gz')):

                # ============
                # list with file paths
                # ============
                file_list[train_test].append(file)

                diag_list[train_test].append(diagnosis_dict[infos['Group']])
                weight_list[train_test].append(infos['Weight'])
                height_list[train_test].append(infos['Height'])

                patient_id_list[train_test].append(patient_id)

                systole_frame = int(infos['ES'])
                diastole_frame = int(infos['ED'])

                file_base = file.split('.')[0]
                frame = int(file_base.split('frame')[-1][:-3])
                if frame == systole_frame:
                    cardiac_phase_list[train_test].append(1)  # 1 == systole
                elif frame == diastole_frame:
                    cardiac_phase_list[train_test].append(2)  # 2 == diastole
                else:
                    cardiac_phase_list[train_test].append(
                        0)  # 0 means other phase

                nifty_img = nib.load(file)
                nx_list[train_test].append(nifty_img.shape[0])
                ny_list[train_test].append(nifty_img.shape[1])
                nz_list[train_test].append(nifty_img.shape[2])
                num_slices[train_test] += nifty_img.shape[2]
                py_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][2])
                px_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][1])
                pz_list[train_test].append(
                    nifty_img.header.structarr['pixdim'][3])

    # ============
    # writing the small datasets
    # ============
    for tt in ['test', 'train', 'validation']:
        hdf5_file.create_dataset('diagnosis_%s' % tt,
                                 data=np.asarray(diag_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('weight_%s' % tt,
                                 data=np.asarray(weight_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('height_%s' % tt,
                                 data=np.asarray(height_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('patient_id_%s' % tt,
                                 data=np.asarray(patient_id_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('cardiac_phase_%s' % tt,
                                 data=np.asarray(cardiac_phase_list[tt],
                                                 dtype=np.uint8))
        hdf5_file.create_dataset('nz_%s' % tt,
                                 data=np.asarray(nz_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('ny_%s' % tt,
                                 data=np.asarray(ny_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('nx_%s' % tt,
                                 data=np.asarray(nx_list[tt], dtype=np.uint16))
        hdf5_file.create_dataset('py_%s' % tt,
                                 data=np.asarray(py_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('px_%s' % tt,
                                 data=np.asarray(px_list[tt],
                                                 dtype=np.float32))
        hdf5_file.create_dataset('pz_%s' % tt,
                                 data=np.asarray(pz_list[tt],
                                                 dtype=np.float32))

    # ============
    # setting sizes according to 2d or 3d
    # ============
    if mode == '3D':  # size [num_patients, nz, nx, ny]
        nz_max, nx, ny = size
        n_train = len(file_list['train'])  # number of patients
        n_test = len(file_list['test'])
        n_val = len(file_list['validation'])

    elif mode == '2D':  # size [num_z_slices_across_all_patients, nx, ny]
        nx, ny = size
        n_test = num_slices['test']
        n_train = num_slices['train']
        n_val = num_slices['validation']

    else:
        raise AssertionError('Wrong mode setting. This should never happen.')

    # ============
    # creating datasets for images and labels
    # ============
    data = {}
    for tt, num_points in zip(['test', 'train', 'validation'],
                              [n_test, n_train, n_val]):

        if num_points > 0:
            data['images_%s' % tt] = hdf5_file.create_dataset(
                "images_%s" % tt, [num_points] + list(size), dtype=np.float32)
            data['labels_%s' % tt] = hdf5_file.create_dataset(
                "labels_%s" % tt, [num_points] + list(size), dtype=np.uint8)

    image_list = {'test': [], 'train': [], 'validation': []}
    label_list = {'test': [], 'train': [], 'validation': []}

    for train_test in ['test', 'train', 'validation']:

        write_buffer = 0
        counter_from = 0
        patient_counter = 0

        for image_file in file_list[train_test]:

            patient_counter += 1

            logging.info('============================================')
            logging.info('Doing: %s' % image_file)

            # ============
            # read image
            # ============
            image_dat = utils.load_nii(image_file)
            image = image_dat[0].copy()

            # ============
            # normalize the image to be between 0 and 1
            # ============
            image = utils.normalise_image(image, norm_type='div_by_max')

            # ============
            # read label
            # ============
            file_base = image_file.split('_n4.nii.gz')[0]
            label_file = file_base + '_gt.nii.gz'
            label_dat = utils.load_nii(label_file)
            label = label_dat[0].copy()

            # ============
            # set RV label to 1 and other labels to 0, as the RVSC dataset only has labels for the RV
            # original labels: 0 background, 1 right ventricle, 2 myocardium, 3 left ventricle
            # ============
            # label[label!=1] = 0

            # ============
            # original pixel size (px, py, pz)
            # ============
            pixel_size = (image_dat[2].structarr['pixdim'][1],
                          image_dat[2].structarr['pixdim'][2],
                          image_dat[2].structarr['pixdim'][3])

            # ========================================================================
            # PROCESSING LOOP FOR 3D DATA
            # ========================================================================
            if mode == '3D':

                # rescaling ratio
                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1],
                    pixel_size[2] / target_resolution[2]
                ]

                # ==============================
                # rescale image and label
                # ==============================
                image_scaled = transform.rescale(image,
                                                 scale_vector,
                                                 order=1,
                                                 preserve_range=True,
                                                 multichannel=False,
                                                 mode='constant')
                label_scaled = transform.rescale(label,
                                                 scale_vector,
                                                 order=0,
                                                 preserve_range=True,
                                                 multichannel=False,
                                                 mode='constant')

                # ==============================
                # ==============================
                image_scaled = utils.crop_or_pad_volume_to_size_along_z(
                    image_scaled, nz_max)
                label_scaled = utils.crop_or_pad_volume_to_size_along_z(
                    label_scaled, nz_max)

                # ==============================
                # nz_max is the z-dimension provided in the 'size' parameter
                # ==============================
                image_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)
                label_vol = np.zeros((nx, ny, nz_max), dtype=np.uint8)

                # ===============================
                # going through each z slice
                # ===============================
                for zz in range(nz_max):

                    image_slice = image_scaled[:, :, zz]
                    label_slice = label_scaled[:, :, zz]

                    # cropping / padding with zeros the x-y slice at this z location
                    image_slice_cropped = utils.crop_or_pad_slice_to_size(
                        image_slice, nx, ny)
                    label_slice_cropped = utils.crop_or_pad_slice_to_size(
                        label_slice, nx, ny)

                    image_vol[:, :, zz] = image_slice_cropped
                    label_vol[:, :, zz] = label_slice_cropped

                # ===============================
                # swap axes to maintain consistent orientation as compared to 2d pre-processing
                # ===============================
                image_vol = image_vol.swapaxes(0, 2).swapaxes(1, 2)
                label_vol = label_vol.swapaxes(0, 2).swapaxes(1, 2)

                # ===============================
                # append to list that will be written to the hdf5 file
                # ===============================
                image_list[train_test].append(image_vol)
                label_list[train_test].append(label_vol)

                write_buffer += 1

                # ===============================
                # writing the images and labels pre-processed so far to the hdf5 file
                # ===============================
                if write_buffer >= MAX_WRITE_BUFFER:

                    counter_to = counter_from + write_buffer
                    _write_range_to_hdf5(data, train_test, image_list,
                                         label_list, counter_from, counter_to)
                    _release_tmp_memory(image_list, label_list, train_test)

                    # reset stuff for next iteration
                    counter_from = counter_to
                    write_buffer = 0

            # ========================================================================
            # PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA
            # ========================================================================
            elif mode == '2D':

                scale_vector = [
                    pixel_size[0] / target_resolution[0],
                    pixel_size[1] / target_resolution[1]
                ]

                # ===============================
                # go through each z slice, rescale and crop and append.
                # in this process, the z axis will become the zeroth axis
                # ===============================
                for zz in range(image.shape[2]):

                    image_slice = np.squeeze(image[:, :, zz])
                    label_slice = np.squeeze(label[:, :, zz])

                    image_slice_rescaled = transform.rescale(
                        image_slice,
                        scale_vector,
                        order=1,
                        preserve_range=True,
                        multichannel=False,
                        mode='constant')
                    label_slice_rescaled = transform.rescale(
                        label_slice,
                        scale_vector,
                        order=0,
                        preserve_range=True,
                        multichannel=False,
                        mode='constant')

                    image_slice_cropped = utils.crop_or_pad_slice_to_size(
                        image_slice_rescaled, nx, ny)
                    label_slice_cropped = utils.crop_or_pad_slice_to_size(
                        label_slice_rescaled, nx, ny)

                    image_list[train_test].append(image_slice_cropped)
                    label_list[train_test].append(label_slice_cropped)

                    write_buffer += 1

                    # Writing needs to happen inside the loop over the slices
                    if write_buffer >= MAX_WRITE_BUFFER:

                        counter_to = counter_from + write_buffer
                        _write_range_to_hdf5(data, train_test, image_list,
                                             label_list, counter_from,
                                             counter_to)
                        _release_tmp_memory(image_list, label_list, train_test)

                        # reset stuff for next iteration
                        counter_from = counter_to
                        write_buffer = 0

        logging.info('Writing remaining data')
        counter_to = counter_from + write_buffer

        _write_range_to_hdf5(data, train_test, image_list, label_list,
                             counter_from, counter_to)
        _release_tmp_memory(image_list, label_list, train_test)

    # After test train loop:
    hdf5_file.close()
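utils.normalise_image with norm_type='div_by_max' appears in every example here; a minimal sketch of what it presumably does, i.e. scale intensities into [0, 1] by the volume maximum (again an assumption, not the project's code):

import numpy as np

def normalise_image(image, norm_type='div_by_max'):
    # Assumed behaviour: divide by the volume maximum (with a small epsilon) so intensities land in [0, 1].
    if norm_type == 'div_by_max':
        return image / (np.max(image) + 1e-10)
    raise ValueError('Unknown norm_type: %s' % norm_type)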
Example #7
def prepare_data(input_folder,
                 output_file,
                 idx_start,
                 idx_end,
                 protocol,
                 size,
                 depth,
                 target_resolution,
                 preprocessing_folder):

    # ========================    
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # =======================
    # create a new hdf5 file
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_slices = count_slices(filenames,
                              idx_start,
                              idx_end,
                              protocol,
                              preprocessing_folder,
                              depth)
    
    # ===============================
    # the sizes of the image and label arrays are set as: [(number of coronal slices per subject*number of subjects), size of coronal slices]
    # ===============================
    data['images'] = hdf5_file.create_dataset("images", [num_slices] + list(size), dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels", [num_slices] + list(size), dtype=np.uint8)
    
    # ===============================
    # initialize lists
    # ===============================        
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []
    
    # ===============================      
    # initialize counters
    # ===============================        
    write_buffer = 0
    counter_from = 0
    
    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):
        
        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(filenames[idx],
                                                                         protocol,
                                                                         preprocessing_folder)
        
        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)
        image = np.swapaxes(image, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        
        # ==================
        # read the label file
        # ==================        
        label, _, _ = utils.load_nii(label_path)        
        label = np.swapaxes(label, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        label = utils.group_segmentation_classes(label) # group the segmentation classes as required
                
        # ==================
        # crop volume along z axis (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
        label = utils.crop_or_pad_volume_to_size_along_z(label, depth)     

        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(float(image_hdr.get_zooms()[2])) # since axes 1 and 2 have been swapped
        pz_list.append(float(image_hdr.get_zooms()[1]))
        nx_list.append(image.shape[0]) 
        ny_list.append(image.shape[1]) # since axes 1 and 2 have been swapped
        nz_list.append(image.shape[2])
        pat_names_list.append(patient_name)
        
        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')
                        
        # ======================================================  
        ### PROCESSING LOOP FOR SLICE-BY-SLICE 2D DATA ###################
        # ======================================================
        scale_vector = [image_hdr.get_zooms()[0] / target_resolution[0],
                        image_hdr.get_zooms()[2] / target_resolution[1]] # since axes 1 and 2 have been swapped

        for zz in range(image.shape[2]):

            # ============
            # rescale the images and labels so that their orientation matches that of the nci dataset
            # ============            
            image2d_rescaled = rescale(np.squeeze(image_normalized[:, :, zz]),
                                                  scale_vector,
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode = 'constant')
 
            label2d_rescaled = rescale(np.squeeze(label[:, :, zz]),
                                                  scale_vector,
                                                  order=0,
                                                  preserve_range=True,
                                                  multichannel=False,
                                                  mode='constant')
            
            # ============            
            # crop or pad to make of the same size
            # ============            
            image2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(image2d_rescaled, size[0], size[1])
            label2d_rescaled_rotated_cropped = utils.crop_or_pad_slice_to_size(label2d_rescaled, size[0], size[1])

            # ============   
            # append to list
            # ============   
            image_list.append(image2d_rescaled_rotated_cropped)
            label_list.append(label2d_rescaled_rotated_cropped)

            # ============   
            # increment counter
            # ============   
            write_buffer += 1

            # ============   
            # Writing needs to happen inside the loop over the slices
            # ============   
            if write_buffer >= MAX_WRITE_BUFFER:

                counter_to = counter_from + write_buffer

                _write_range_to_hdf5(data,
                                     image_list,
                                     label_list,
                                     counter_from,
                                     counter_to)
                
                _release_tmp_memory(image_list,
                                    label_list)

                # ============   
                # update counters 
                # ============   
                counter_from = counter_to
                write_buffer = 0
        
    # ============   
    # write leftover data
    # ============   
    logging.info('Writing remaining data')
    counter_to = counter_from + write_buffer
    _write_range_to_hdf5(data,
                         image_list,
                         label_list,
                         counter_from,
                         counter_to)
    _release_tmp_memory(image_list,
                        label_list)

    # ============   
    # Write the small datasets - image resolutions, sizes, patient ids
    # ============   
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames', data=np.asarray(pat_names_list, dtype="S10"))
    
    # ============   
    # close the hdf5 file
    # ============   
    hdf5_file.close()
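count_slices is not shown either; given that every volume is cropped or padded to exactly depth coronal slices before being written, a straightforward guess is that it simply multiplies the number of requested subjects by depth:

def count_slices(filenames, idx_start, idx_end, protocol, preprocessing_folder, depth):
    # Assumed implementation: each subject contributes exactly 'depth' slices after
    # crop_or_pad_volume_to_size_along_z, so the total is subjects * depth.
    # The unused arguments only mirror the call signature used in the example above.
    return (idx_end - idx_start) * depth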
Example #8
def main():

    # ===================================
    # read the test images
    # ===================================
    test_dataset_name = exp_config.test_dataset

    if test_dataset_name == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)

        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70

        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)

        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70

        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')

        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36

        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'STANFORD/')

        image_depth = exp_config.image_depth_stanford
        idx_start = 16
        idx_end = 36

        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    imts = data_brain_test['images']
    name_test_subjects = data_brain_test['patnames']
    num_test_subjects = imts.shape[0] // image_depth
    ids = np.arange(idx_start, idx_end)

    orig_data_res_x = data_brain_test['px'][:]
    orig_data_res_y = data_brain_test['py'][:]
    orig_data_res_z = data_brain_test['pz'][:]
    orig_data_siz_x = data_brain_test['nx'][:]
    orig_data_siz_y = data_brain_test['ny'][:]
    orig_data_siz_z = data_brain_test['nz'][:]

    # ================================
    # set the log directory
    # ================================
    if exp_config.normalize:
        log_dir = os.path.join(sys_config.log_root,
                               exp_config.expname_normalizer)
    else:
        log_dir = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l

    # ================================================================
    # For each test image, load the best model and compute the dice with this model
    # ================================================================
    for sub_num in range(5):  # NOTE: only the first 5 test subjects are evaluated here

        subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
        subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num + 1])
        image = imts[subject_id_start_slice:subject_id_end_slice, :, :]

        # ==================================================================
        # setup logging
        # ==================================================================
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s %(message)s')
        subject_name = str(name_test_subjects[sub_num])[2:-1]  # strip the b'...' wrapper from the stored byte string
        logging.info(
            '============================================================')
        logging.info('Subject id: %s' % sub_num)

        # ==================================================================
        # predict segmentation at the pre-processed resolution
        # ==================================================================
        predicted_labels, normalized_image = predict_segmentation(
            subject_name, image, exp_config.normalize)

        # ==================================================================
        # read the original segmentation mask
        # ==================================================================

        if test_dataset_name == 'HCPT1':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T1',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'HCPT2':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T2',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'CALTECH':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='CALTECH',
                idx=ids[sub_num],
                depth=image_depth)
            num_rotations = 0

        elif test_dataset_name == 'STANFORD':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='STANFORD',
                idx=ids[sub_num],
                depth=image_depth)
            num_rotations = 0

        # ==================================================================
        # convert the predictions back to original resolution
        # ==================================================================
        predicted_labels_orig_res_and_size = rescale_and_crop(
            predicted_labels,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=0,
            num_rotations=num_rotations)

        normalized_image_orig_res_and_size = rescale_and_crop(
            normalized_image,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=1,
            num_rotations=num_rotations)

        # ================================================================
        # save sample results
        # ================================================================
        x_true = utils.crop_or_pad_volume_to_size_along_z(image_orig, 256)
        z_true = utils.crop_or_pad_volume_to_size_along_z(labels_orig, 256)
        x_norm = utils.crop_or_pad_volume_to_size_along_z(
            normalized_image_orig_res_and_size, 256)
        z_pred = utils.crop_or_pad_volume_to_size_along_z(
            predicted_labels_orig_res_and_size, 256)

        # basepath = os.path.join(sys_config.log_root, exp_config.expname_normalizer) + '/subject_' + subject_name + '/results/tta' + str(exp_config.normalize)
        basepath = log_dir + '/' + test_dataset_name + '_' + 'test' + '_' + subject_name
        for zz in np.arange(120, 130, 10):  # with this range and step, only slice 120 is saved
            utils_vis.save_single_image(
                x_true[:, :, zz], basepath + 'slice' + str(zz) + '_x_true.png',
                15, False, 'gray', False)
            utils_vis.save_single_image(x_norm[:, :, zz],
                                        basepath + 'slice' + str(zz) +
                                        '_x_norm.png',
                                        15,
                                        False,
                                        'gray',
                                        False,
                                        climits=[-1.0, 2.0])
            utils_vis.save_single_image(
                z_true[:, :, zz], basepath + 'slice' + str(zz) + '_z_true.png',
                15, True, 'tab20', False)
            utils_vis.save_single_image(
                z_pred[:, :, zz], basepath + 'slice' + str(zz) + '_z_pred.png',
                15, True, 'tab20', False)
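# ----------------------------------------------------------------------------
# rescale_and_crop is used in the main() examples but not defined here. The
# sketch below is an assumed implementation, not the repository's own code:
# it (i) rotates each slice by 90-degree steps, (ii) rescales the in-plane
# resolution from the preprocessed target resolution back to the subject's
# original resolution, and (iii) crops or pads each slice to the original
# in-plane size. exp_config.target_resolution (target_resolution_brain in the
# example above) and utils.crop_or_pad_slice_to_size are taken from the
# surrounding code; the axis layout (z returned as the last axis) is an
# assumption based on how the result is passed to
# crop_or_pad_volume_to_size_along_z afterwards.
# ----------------------------------------------------------------------------
import numpy as np
from skimage import transform


def rescale_and_crop(volume, px, py, nx, ny, order_interpolation, num_rotations):
    # volume: [n_slices, x, y] at the preprocessed resolution
    scale = [exp_config.target_resolution[0] / px,
             exp_config.target_resolution[1] / py]
    out_slices = []
    for zz in range(volume.shape[0]):
        slc = np.rot90(volume[zz, :, :], k=num_rotations)
        slc = transform.rescale(slc, scale, order=order_interpolation,
                                preserve_range=True, mode='constant')
        out_slices.append(utils.crop_or_pad_slice_to_size(slc, nx, ny))
    # put z back as the last axis so that downstream z-cropping works as expected
    return np.moveaxis(np.array(out_slices), 0, -1)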
Example #9
def main():
    
    # ===================================
    # read the test images
    # ===================================
    if exp_config.evaluate_td:
        test_dataset_name = exp_config.test_dataset
    else:
        test_dataset_name = exp_config.train_dataset
    
    if test_dataset_name == 'HCPT1':
        logging.info('Reading HCPT1 images...')    
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70       
        
        data_test = data_hcp.load_and_maybe_process_data(input_folder = sys_config.orig_data_root_hcp,
                                                         preprocessing_folder = sys_config.preproc_folder_hcp,
                                                         idx_start = idx_start,
                                                         idx_end = idx_end,                
                                                         protocol = 'T1',
                                                         size = exp_config.image_size,
                                                         depth = image_depth,
                                                         target_resolution = exp_config.target_resolution)
        
        imts = data_test['images']
        name_test_subjects = data_test['patnames']
        num_test_subjects = imts.shape[0] // image_depth
        ids = np.arange(idx_start, idx_end)       
        
        orig_data_res_x = data_test['px'][:]
        orig_data_res_y = data_test['py'][:]
        orig_data_res_z = data_test['pz'][:]
        orig_data_siz_x = data_test['nx'][:]
        orig_data_siz_y = data_test['ny'][:]
        orig_data_siz_z = data_test['nz'][:]
        
    elif test_dataset_name == 'HCPT2':
        logging.info('Reading HCPT2 images...')    
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        
        data_test = data_hcp.load_and_maybe_process_data(input_folder = sys_config.orig_data_root_hcp,
                                                         preprocessing_folder = sys_config.preproc_folder_hcp,
                                                         idx_start = idx_start,
                                                         idx_end = idx_end,           
                                                         protocol = 'T2',
                                                         size = exp_config.image_size,
                                                         depth = image_depth,
                                                         target_resolution = exp_config.target_resolution)
        
        imts = data_test['images']
        name_test_subjects = data_test['patnames']
        num_test_subjects = imts.shape[0] // image_depth
        ids = np.arange(idx_start, idx_end)       
        
        orig_data_res_x = data_test['px'][:]
        orig_data_res_y = data_test['py'][:]
        orig_data_res_z = data_test['pz'][:]
        orig_data_siz_x = data_test['nx'][:]
        orig_data_siz_y = data_test['ny'][:]
        orig_data_siz_z = data_test['nz'][:]
        
    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')    
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/')
        
        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36         
        
        data_test = data_abide.load_and_maybe_process_data(input_folder = sys_config.orig_data_root_abide,
                                                           preprocessing_folder = sys_config.preproc_folder_abide,
                                                           site_name = 'CALTECH',
                                                           idx_start = idx_start,
                                                           idx_end = idx_end,             
                                                           protocol = 'T1',
                                                           size = exp_config.image_size,
                                                           depth = image_depth,
                                                           target_resolution = exp_config.target_resolution)        
    
        imts = data_test['images']
        name_test_subjects = data_test['patnames']
        num_test_subjects = imts.shape[0] // image_depth
        ids = np.arange(idx_start, idx_end)       
        
        orig_data_res_x = data_test['px'][:]
        orig_data_res_y = data_test['py'][:]
        orig_data_res_z = data_test['pz'][:]
        orig_data_siz_x = data_test['nx'][:]
        orig_data_siz_y = data_test['ny'][:]
        orig_data_siz_z = data_test['nz'][:]
            
    elif test_dataset_name == 'NCI':
        data_test = data_nci.load_and_maybe_process_data(input_folder=sys_config.orig_data_root_nci,
                                                         preprocessing_folder=sys_config.preproc_folder_nci,
                                                         size=exp_config.image_size,
                                                         target_resolution=exp_config.target_resolution,
                                                         force_overwrite=False,
                                                         cv_fold_num = 1)

        imts = data_test['images_test']
        name_test_subjects = data_test['patnames_test']

        orig_data_res_x = data_test['px_test'][:]
        orig_data_res_y = data_test['py_test'][:]
        orig_data_res_z = data_test['pz_test'][:]
        orig_data_siz_x = data_test['nx_test'][:]
        orig_data_siz_y = data_test['ny_test'][:]
        orig_data_siz_z = data_test['nz_test'][:]

        num_test_subjects = orig_data_siz_z.shape[0]
        ids = np.arange(num_test_subjects)

    elif test_dataset_name == 'PIRAD_ERC':

        idx_start = 0
        idx_end = 20
        ids = np.arange(idx_start, idx_end)

        data_test = data_pirad_erc.load_data(input_folder=sys_config.orig_data_root_pirad_erc,
                                             preproc_folder=sys_config.preproc_folder_pirad_erc,
                                             idx_start=idx_start,
                                             idx_end=idx_end,
                                             size=exp_config.image_size,
                                             target_resolution=exp_config.target_resolution,
                                             labeller='ek')
        imts = data_test['images']
        name_test_subjects = data_test['patnames']

        orig_data_res_x = data_test['px'][:]
        orig_data_res_y = data_test['py'][:]
        orig_data_res_z = data_test['pz'][:]
        orig_data_siz_x = data_test['nx'][:]
        orig_data_siz_y = data_test['ny'][:]
        orig_data_siz_z = data_test['nz'][:]

        num_test_subjects = orig_data_siz_z.shape[0]
        
    # ================================   
    # set the log directory
    # ================================   
    if exp_config.normalize:
        log_dir = os.path.join(sys_config.log_root, exp_config.expname_normalizer)
    else:
        if not exp_config.uda:
            log_dir = sys_config.log_root + exp_config.expname_i2l
        else:
            log_dir = sys_config.log_root + exp_config.expname_uda

    # ================================   
    # open a text file for writing the mean dice scores for each subject that is evaluated
    # ================================       
    results_file = open(log_dir + '/' + test_dataset_name + '_' + 'test' + '.txt', "w")
    results_file.write("================================== \n") 
    results_file.write("Test results \n") 
    
    # ================================================================
    # For each test image, load the best model and compute the dice with this model
    # ================================================================
    dice_per_label_per_subject = []
    hsd_per_label_per_subject = []

    for sub_num in range(5):  # NOTE: only the first 5 subjects; use range(num_test_subjects) to evaluate all of them

        subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
        subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num+1])
        image = imts[subject_id_start_slice:subject_id_end_slice,:,:]  
        
        # ==================================================================
        # setup logging
        # ==================================================================
        logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
        subject_name = str(name_test_subjects[sub_num])[2:-1]
        logging.info('============================================================')
        logging.info('Subject id: %s' %sub_num)
    
        # ==================================================================
        # predict segmentation at the pre-processed resolution
        # ==================================================================
        predicted_labels, normalized_image = predict_segmentation(subject_name,
                                                                  image,
                                                                  exp_config.normalize)

        # ==================================================================
        # read the original segmentation mask
        # ==================================================================
        if test_dataset_name == 'HCPT1':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(input_folder = sys_config.orig_data_root_hcp,
                                                                              idx = ids[sub_num],
                                                                              protocol = 'T1',
                                                                              preprocessing_folder = sys_config.preproc_folder_hcp,
                                                                              depth = image_depth)
            num_rotations = 0  
            
        elif test_dataset_name == 'HCPT2':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(input_folder = sys_config.orig_data_root_hcp,
                                                                              idx = ids[sub_num],
                                                                              protocol = 'T2',
                                                                              preprocessing_folder = sys_config.preproc_folder_hcp,
                                                                              depth = image_depth)
            num_rotations = 0  

        elif test_dataset_name == 'CALTECH':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(input_folder = sys_config.orig_data_root_abide,
                                                                               site_name = 'CALTECH',
                                                                               idx = ids[sub_num],
                                                                               depth = image_depth)
            num_rotations = 0

        elif test_dataset_name == 'STANFORD':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(input_folder = sys_config.orig_data_root_abide,
                                                                               site_name = 'STANFORD',
                                                                               idx = ids[sub_num],
                                                                               depth = image_depth)
            num_rotations = 0
            
        elif test_dataset_name == 'NCI':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_nci.load_without_size_preprocessing(sys_config.orig_data_root_nci,
                                                                               cv_fold_num=1,
                                                                               train_test='test',
                                                                               idx=ids[sub_num])
            num_rotations = 0

        elif test_dataset_name == 'PIRAD_ERC':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_pirad_erc.load_without_size_preprocessing(sys_config.orig_data_root_pirad_erc,
                                                                                     ids[sub_num],
                                                                                     labeller='ek')
            num_rotations = -3
            
        # ==================================================================
        # convert the predictions back to original resolution
        # ==================================================================
        predicted_labels_orig_res_and_size = rescale_and_crop(predicted_labels,
                                                              orig_data_res_x[sub_num],
                                                              orig_data_res_y[sub_num],
                                                              orig_data_siz_x[sub_num],
                                                              orig_data_siz_y[sub_num],
                                                              order_interpolation = 0,
                                                              num_rotations = num_rotations)
        
        normalized_image_orig_res_and_size = rescale_and_crop(normalized_image,
                                                              orig_data_res_x[sub_num],
                                                              orig_data_res_y[sub_num],
                                                              orig_data_siz_x[sub_num],
                                                              orig_data_siz_y[sub_num],
                                                              order_interpolation = 1,
                                                              num_rotations = num_rotations)
        
        # ==================================================================
        # If only whole-gland comparisons are desired, merge all foreground labels in both the ground truth segmentations and the predictions
        # ==================================================================
        if exp_config.whole_gland_results:
            predicted_labels_orig_res_and_size[predicted_labels_orig_res_and_size!=0] = 1
            labels_orig[labels_orig!=0] = 1
            nl = 2
            savepath = log_dir + '/' + test_dataset_name + '_test_' + subject_name + '_whole_gland.png'
        else:
            nl = exp_config.nlabels
            savepath = log_dir + '/' + test_dataset_name + '_test_' + subject_name + '.png'
            
        # ==================================================================
        # compute dice at the original resolution
        # ==================================================================    
        dice_per_label_this_subject = met.f1_score(labels_orig.flatten(),
                                                   predicted_labels_orig_res_and_size.flatten(),
                                                   average=None)
        
        # ==================================================================    
        # compute Hausdorff distance at the original resolution
        # ==================================================================    
        compute_hsd = False
        if compute_hsd:
            hsd_per_label_this_subject = utils.compute_surface_distance(y1 = labels_orig,
                                                                        y2 = predicted_labels_orig_res_and_size,
                                                                        nlabels = exp_config.nlabels)
        else:
            hsd_per_label_this_subject = np.zeros((exp_config.nlabels))
        
        # ================================================================
        # save sample results
        # ================================================================
        d_vis = 32 # 256
        ids_vis = np.arange(0, 32, 4) # ids = np.arange(48, 256-48, (256-96)//8)
        utils_vis.save_sample_prediction_results(x = utils.crop_or_pad_volume_to_size_along_z(image_orig, d_vis),
                                                 x_norm = utils.crop_or_pad_volume_to_size_along_z(normalized_image_orig_res_and_size, d_vis),
                                                 y_pred = utils.crop_or_pad_volume_to_size_along_z(predicted_labels_orig_res_and_size, d_vis),
                                                 gt = utils.crop_or_pad_volume_to_size_along_z(labels_orig, d_vis),
                                                 num_rotations = - num_rotations, # rotate for consistent visualization across datasets
                                                 savepath = savepath,
                                                 nlabels = nl,
                                                 ids=ids_vis)
                                   
        # ================================
        # write the mean fg dice of this subject to the text file
        # ================================
        results_file.write(subject_name + ":: dice (mean, std over all FG labels): ")
        results_file.write(str(np.round(np.mean(dice_per_label_this_subject[1:]), 3)) + ", " + str(np.round(np.std(dice_per_label_this_subject[1:]), 3)))
        results_file.write(", hausdorff distance (mean, std over all FG labels): ")
        results_file.write(str(np.round(np.mean(hsd_per_label_this_subject), 3)) + ", " + str(np.round(np.std(hsd_per_label_this_subject), 3)) + "\n")
        
        dice_per_label_per_subject.append(dice_per_label_this_subject)
        hsd_per_label_per_subject.append(hsd_per_label_this_subject)
    
    # ================================================================
    # write per label statistics over all subjects    
    # ================================================================
    dice_per_label_per_subject = np.array(dice_per_label_per_subject)
    hsd_per_label_per_subject =  np.array(hsd_per_label_per_subject)
    
    # ================================
    # In the array images_dice, in the rows, there are subjects
    # and in the columns, there are the dice scores for each label for a particular subject
    # ================================
    results_file.write("================================== \n") 
    results_file.write("Label: dice mean, std. deviation over all subjects\n")
    for i in range(dice_per_label_per_subject.shape[1]):
        results_file.write(str(i) + ": " + str(np.round(np.mean(dice_per_label_per_subject[:,i]), 3)) + ", " + str(np.round(np.std(dice_per_label_per_subject[:,i]), 3)) + "\n")
    results_file.write("================================== \n") 
    results_file.write("Label: hausdorff distance mean, std. deviation over all subjects\n")
    for i in range(hsd_per_label_per_subject.shape[1]):
        results_file.write(str(i+1) + ": " + str(np.round(np.mean(hsd_per_label_per_subject[:,i]), 3)) + ", " + str(np.round(np.std(hsd_per_label_per_subject[:,i]), 3)) + "\n")
    
    # ==================
    # write the mean dice over all subjects and all labels
    # ==================
    results_file.write("================================== \n") 
    results_file.write("DICE Mean, std. deviation over foreground labels over all subjects: " + str(np.round(np.mean(dice_per_label_per_subject[:,1:]), 3)) + ", " + str(np.round(np.std(dice_per_label_per_subject[:,1:]), 3)) + "\n")
    results_file.write("HSD Mean, std. deviation over labels over all subjects: " + str(np.round(np.mean(hsd_per_label_per_subject), 3)) + ", " + str(np.round(np.std(hsd_per_label_per_subject), 3)) + "\n")
    results_file.write("================================== \n") 
    results_file.close()
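# ----------------------------------------------------------------------------
# The per-label Dice scores above are computed with met.f1_score(..., average=None);
# assuming met is sklearn.metrics, the per-class F1 score on the flattened label
# maps is exactly the per-class Dice coefficient 2|A∩B| / (|A| + |B|). A small
# self-contained check of that equivalence:
# ----------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import f1_score

gt   = np.array([0, 0, 1, 1, 2, 2, 2, 0])   # toy ground-truth labels
pred = np.array([0, 1, 1, 1, 2, 0, 2, 0])   # toy predicted labels

# per-label Dice via f1_score (one value per label, background label 0 included)
dice_sklearn = f1_score(gt, pred, average=None)

# the same quantity from the Dice definition
dice_manual = []
for lbl in np.unique(gt):
    a, b = (gt == lbl), (pred == lbl)
    dice_manual.append(2.0 * np.sum(a & b) / (np.sum(a) + np.sum(b)))

print(dice_sklearn)           # -> [0.667 0.8 0.8] (up to print formatting)
print(np.array(dice_manual))  # -> the same values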
Example #10
def prepare_data(input_folder,
                 output_file,
                 idx_start,
                 idx_end,
                 protocol,
                 size,
                 depth,
                 target_resolution,
                 preprocessing_folder):

    # ========================    
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    # =======================
    # create a new hdf5 file
    # =======================
    hdf5_file = h5py.File(output_file, "w")

    # ===============================
    # Create datasets for images and labels
    # ===============================
    data = {}
    num_subjects = idx_end - idx_start
    
    data['images'] = hdf5_file.create_dataset("images", [num_subjects] + list(size), dtype=np.float32)
    data['labels'] = hdf5_file.create_dataset("labels", [num_subjects] + list(size), dtype=np.uint8)
    
    # ===============================
    # initialize lists
    # ===============================        
    label_list = []
    image_list = []
    nx_list = []
    ny_list = []
    nz_list = []
    px_list = []
    py_list = []
    pz_list = []
    pat_names_list = []
    
    # ===============================        
    # initialize counter
    # ===============================        
    patient_counter = 0
    
    # ===============================
    # iterate through the requested indices
    # ===============================
    for idx in range(idx_start, idx_end):
        
        # ==================
        # get file paths
        # ==================
        patient_name, image_path, label_path = get_image_and_label_paths(filenames[idx],
                                                                         protocol,
                                                                         preprocessing_folder)
        
        # ============
        # read the image and normalize it to be between 0 and 1
        # ============
        image, _, image_hdr = utils.load_nii(image_path)
        image = np.swapaxes(image, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        
        # ==================
        # read the label file
        # ==================        
        label, _, _ = utils.load_nii(label_path)        
        label = np.swapaxes(label, 1, 2) # swap axes 1 and 2 -> this allows appending along axis 2, as in other datasets
        label = utils.group_segmentation_classes(label) # group the segmentation classes as required
        
        # ==================
        # collect some header info.
        # ==================
        px_list.append(float(image_hdr.get_zooms()[0]))
        py_list.append(float(image_hdr.get_zooms()[2])) # since axes 1 and 2 have been swapped
        pz_list.append(float(image_hdr.get_zooms()[1]))
        nx_list.append(image.shape[0]) 
        ny_list.append(image.shape[2]) # since axes 1 and 2 have been swapped
        nz_list.append(image.shape[1])
        pat_names_list.append(patient_name)
        
        # ==================
        # crop volume along z axis (as there are several zeros towards the ends)
        # ==================
        image = utils.crop_or_pad_volume_to_size_along_z(image, depth)
        label = utils.crop_or_pad_volume_to_size_along_z(label, depth)     
        
        # ==================
        # normalize the image
        # ==================
        image_normalized = utils.normalise_image(image, norm_type='div_by_max')
                        
        # ======================================================  
        # rescale, crop / pad to make all images of the required size and resolution
        # ======================================================
        scale_vector = [image_hdr.get_zooms()[0] / target_resolution[0],
                        image_hdr.get_zooms()[2] / target_resolution[1],
                        image_hdr.get_zooms()[1] / target_resolution[2]] # since axes 1 and 2 have been swapped
        
        image_rescaled = transform.rescale(image_normalized,
                                           scale_vector,
                                           order=1,
                                           preserve_range=True,
                                           multichannel=False,
                                           mode = 'constant')

        label_onehot = utils.make_onehot_(label, nlabels=15)

        label_onehot_rescaled = transform.rescale(label_onehot,
                                                  scale_vector,
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=True,
                                                  mode='constant')
        
        label_rescaled = np.argmax(label_onehot_rescaled, axis=-1)
        
        # ==================================
        # go through each z slice, crop or pad it to a constant size and then append the resized slice
        # this ensures that the axes end up in the same orientation as during the 2d preprocessing
        # ==================================
        image_rescaled_cropped = []
        label_rescaled_cropped = []
        for zz in range(image_rescaled.shape[2]):
            image_rescaled_cropped.append(utils.crop_or_pad_slice_to_size(image_rescaled[:,:,zz], size[1], size[2]))
            label_rescaled_cropped.append(utils.crop_or_pad_slice_to_size(label_rescaled[:,:,zz], size[1], size[2]))
        image_rescaled_cropped = np.array(image_rescaled_cropped)
        label_rescaled_cropped = np.array(label_rescaled_cropped)

        # ============   
        # append to list
        # ============   
        image_list.append(image_rescaled_cropped)
        label_list.append(label_rescaled_cropped)

        # ============   
        # write to file
        # ============   
        _write_range_to_hdf5(data,
                             image_list,
                             label_list,
                             patient_counter,
                             patient_counter+1)
        
        _release_tmp_memory(image_list,
                            label_list)
        
        # update counter
        patient_counter += 1

    # Write the small datasets
    hdf5_file.create_dataset('nx', data=np.asarray(nx_list, dtype=np.uint16))
    hdf5_file.create_dataset('ny', data=np.asarray(ny_list, dtype=np.uint16))
    hdf5_file.create_dataset('nz', data=np.asarray(nz_list, dtype=np.uint16))
    hdf5_file.create_dataset('px', data=np.asarray(px_list, dtype=np.float32))
    hdf5_file.create_dataset('py', data=np.asarray(py_list, dtype=np.float32))
    hdf5_file.create_dataset('pz', data=np.asarray(pz_list, dtype=np.float32))
    hdf5_file.create_dataset('patnames', data=np.asarray(pat_names_list, dtype="S10"))
    
    # close the hdf5 file after the loop over all subjects
    hdf5_file.close()
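# ----------------------------------------------------------------------------
# prepare_data writes the preprocessed dataset to an HDF5 file once; the
# load_and_maybe_process_data functions called in the main() examples
# presumably wrap it with a cache check. The sketch below is an assumed
# wrapper (the file-naming scheme in particular is hypothetical): it reuses
# an existing HDF5 file and only regenerates it when the file is missing or
# force_overwrite is set.
# ----------------------------------------------------------------------------
import os
import logging
import h5py


def load_and_maybe_process_data(input_folder, preprocessing_folder, idx_start, idx_end,
                                protocol, size, depth, target_resolution,
                                force_overwrite=False):
    # build a file name that encodes the preprocessing settings (hypothetical scheme)
    size_str = '_'.join(str(s) for s in size)
    res_str = '_'.join(str(r) for r in target_resolution)
    data_file_name = 'data_%s_size_%s_res_%s_from_%d_to_%d.hdf5' % (
        protocol, size_str, res_str, idx_start, idx_end)
    data_file_path = os.path.join(preprocessing_folder, data_file_name)

    if force_overwrite or not os.path.exists(data_file_path):
        logging.info('Preprocessed file not found, generating it...')
        prepare_data(input_folder, data_file_path, idx_start, idx_end,
                     protocol, size, depth, target_resolution, preprocessing_folder)
    else:
        logging.info('Using existing preprocessed file: ' + data_file_path)

    return h5py.File(data_file_path, 'r')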