Ejemplo n.º 1
0
    def create_one_3D_file():
        '''
        Create one big file which contains all 3D Images (not slices).

        For every subject returned by get_all_subjects():
          - the peak image is loaded, NaNs are set to 0 and it is scaled to the UNet input shape;
            all subjects are saved together in "data.npy" (current working directory)
          - the multilabel bundle mask is created and scaled the same way;
            all subjects are saved together in "seg.npy" (current working directory)
        '''

        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            DATASET_FOLDER = "HCP"

        data_all = []
        seg_all = []

        print("\n\nProcessing Data...")
        for s in get_all_subjects():
            print("processing data subject {}".format(s))
            # np.asanyarray(img.dataobj) is nibabel's replacement for the deprecated get_data()
            data = np.asanyarray(nib.load(join(C.HOME, HP.DATASET_FOLDER, s, HP.FEATURES_FILENAME + ".nii.gz")).dataobj)
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
            # Append inside the loop: every subject has to end up in the file, not only the last one
            data_all.append(np.array(data))
        np.save("data.npy", data_all)
        del data_all  # free memory before the segmentations are created

        print("\n\nProcessing Segs...")
        for s in get_all_subjects():
            print("processing seg subject {}".format(s))
            seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            seg_all.append(np.array(seg))
        # Convert to an array first: a plain list has no .dtype
        seg_all = np.array(seg_all)
        print("SEG TYPE: {}".format(seg_all.dtype))
        np.save("seg.npy", seg_all)
Ejemplo n.º 2
0
    def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
        '''
        Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape

        :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
        :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED"
                        (which of them are supported depends on the resolution)
        :param resolution: "1.25mm" / "2mm" / "2.5mm"     results in UNet input shape of (144,144,144) or (80,80,80)
        :return: img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
                    (note: for 2.5mm the image is padded to 80,80,80 with the value of its first voxel)
        :raises ValueError: if the combination of dataset and resolution is not supported
        '''

        def paste_into_bg(img, bg_region, img_region):
            # (80,80,80,none) volume filled with the value of the first voxel of img, so the padding has
            # the same value as the background of the original image (broadcast over the last dim)
            bg = np.zeros((80, 80, 80, img.shape[3])).astype(img.dtype)
            bg = bg + img[0, 0, 0, :]
            bg[bg_region] = img[img_region]
            return bg  # (80,80,80,none)

        if resolution == "1.25mm":
            if dataset == "HCP":  # (145,174,145)
                # no resize needed
                return img4d[1:, 15:159, 1:]  # (144,144,144)
            if dataset == "HCP_32g":  # (73,87,73)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
                img4d = img4d[:-1, :, :-1]  # remove one voxel that came from upsampling  (145,174,145)
                return img4d[1:, 15:159, 1:]  # (144,144,144)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")

        elif resolution == "2mm":
            if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
                return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
            if dataset == "HCP_2mm":  # (90,108,90)
                # no resize needed
                return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")

        elif resolution == "2.5mm":
            if dataset == "HCP":  # (145,174,145)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
                return paste_into_bg(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])
            if dataset in ("HCP_2.5mm", "HCP_32g"):  # (73,87,73,none), no resize needed
                return paste_into_bg(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])
            if dataset == "TRACED":  # (78,93,75), no resize needed
                return paste_into_bg(img4d, np.s_[1:79, :, 3:78, :], np.s_[:, 7:87, :, :])

        # Previously an unsupported combination silently returned None
        raise ValueError("dataset '{}' not supported for resolution '{}'".format(dataset, resolution))
Ejemplo n.º 3
0
    def precompute_batches(custom_type=None):
        '''
        Generate batches with DataManagerTrainingNiftiImgs and save each batch as two nifti files
        (batch_<idx>_data.nii.gz and batch_<idx>_seg.nii.gz) in C.HOME/<DATASET_DIR>/<split>.

        9000 slices per epoch -> 200 batches (batchsize=44) per epoch
        => 200-1000 batches needed

        Presets for DATASET_DIR:
        270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
        All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
        270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

        :param custom_type: "train" / "validate" / "test" to only generate batches for this split;
                            None generates all three splits
        '''

        class HP:
            NORMALIZE_DATA = True
            DATA_AUGMENTATION = False
            CV_FOLD = 0
            INPUT_DIM = (144, 144)
            BATCH_SIZE = 44
            DATASET_FOLDER = "HCP"
            TYPE = "single_direction"
            EXP_PATH = "~"
            LABELS_FILENAME = "bundle_peaks"
            FEATURES_FILENAME = "270g_125mm_peaks"
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            LABELS_TYPE = np.float32

        HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

        # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
        # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
        DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"

        num_batches_base = 5000
        num_batches = {
            "train": num_batches_base,
            "validate": int(num_batches_base / 3.),
            "test": int(num_batches_base / 3.),
        }

        splits = ["train", "validate", "test"] if custom_type is None else [custom_type]

        # Only depends on dataset and resolution -> compute once instead of twice per batch
        affine = ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)

        # "split" instead of "type" to not shadow the builtin
        for split in splits:
            out_dir = join(C.HOME, DATASET_DIR, split)
            ExpUtils.make_dir(out_dir)  # once per split instead of once per batch

            dataManager = DataManagerTrainingNiftiImgs(HP)
            batch_gen = dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=split,
                                                subjects=getattr(HP, split.upper() + "_SUBJECTS"),
                                                num_batches=num_batches[split])

            for idx, batch in enumerate(batch_gen):
                print("Processing: {}".format(idx))

                data = nib.Nifti1Image(batch["data"], affine)
                nib.save(data, join(out_dir, "batch_" + str(idx) + "_data.nii.gz"))

                seg = nib.Nifti1Image(batch["seg"], affine)
                nib.save(seg, join(out_dir, "batch_" + str(idx) + "_seg.nii.gz"))
Ejemplo n.º 4
0
    def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):
        '''
        Build a slice dataset from the probability maps of the three direction UNets (x, y, z).

        Per subject the probmaps of UNet_x/UNet_y/UNet_z (fold HP.CV_FOLD) are stacked on a new axis,
        that axis is merged into the class axis, the result is saved as nifti to
        HP.EXP_PATH/combined/<subject>_combinded_probmap.nii.gz and then cut into z-slices together
        with the matching multilabel mask.

        :param HP: experiment config object
        :param subjects: subject ids to process
        :param filename: prefix of the output files <filename>_data.npy / <filename>_seg.npy
        :param bundle: not used
        :param shuffle: shuffle the slices (random_state=9) before saving
        :return: None; the .npy files are only written if HP.TRAIN is set
        '''
        mask_dir = join(C.HOME, HP.DATASET_FOLDER)  # NOTE(review): not used below

        input_dir = HP.MULTI_PARENT_PATH

        combined_slices = []
        mask_slices = []

        for subject in subjects:
            print("processing subject {}".format(subject))

            # One probmap per UNet direction
            probmaps = [
                nib.load(join(input_dir, "UNet_" + axis + "_" + str(HP.CV_FOLD), "probmaps",
                              subject + "_probmap.nii.gz")).get_data()
                for axis in ("x", "y", "z")
            ]
            # probmaps = [DatasetUtils.scale_input_to_unet_shape(p, HP.DATASET, HP.RESOLUTION) for p in probmaps]

            # (x, y, z, classes, 3) has one dim too much for the UNet -> merge the last two dims
            stacked = np.stack(probmaps, axis=4)
            nx, ny, nz, nr_classes, nr_directions = stacked.shape
            combined = stacked.reshape((nx, ny, nz, nr_classes * nr_directions))

            # For HCP_2mm / HCP_2.5mm the mask still needs downscaling -> scale it like an "HCP" image.
            # Otherwise the mask has the same resolution as the probmaps and gets the same resizing.
            mask_dataset = "HCP" if HP.DATASET in ("HCP_2mm", "HCP_2.5mm") else HP.DATASET
            mask_data = ImgUtils.create_multilabel_mask(HP, subject, labels_type=HP.LABELS_TYPE)
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, mask_dataset, HP.RESOLUTION)

            # Save as image ("combinded" is part of the existing file name)
            combined_img = nib.Nifti1Image(combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(combined_img, join(HP.EXP_PATH, "combined", subject + "_combinded_probmap.nii.gz"))

            combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
            assert combined.shape[2] == mask_data.shape[2]

            # Save as slices along z
            nr_slices = combined.shape[2]
            combined_slices.extend(combined[:, :, z, :] for z in range(nr_slices))
            mask_slices.extend(mask_data[:, :, z, :] for z in range(nr_slices))

        if shuffle:
            combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9)

        if HP.TRAIN:
            np.save(filename + "_data.npy", combined_slices)
            np.save(filename + "_seg.npy", mask_slices)
Ejemplo n.º 5
0
    def move_to_MNI_space(input_file, bvals, bvecs, brain_mask, output_dir):
        '''
        Rigidly register the input DWI to MNI space.

        1. calc_FA on the input -> output_dir/FA.nii.gz
        2. flirt of FA onto the MNI FA template (6 dof, mutual information)
           -> output_dir/FA_MNI.nii.gz and output_dir/FA_2_MNI.mat
        3. apply that transformation to the DWI, resampled with the voxel spacing of the input
           -> output_dir/Diffusion_MNI.nii.gz
        4. copy bvals/bvecs unchanged to output_dir/Diffusion_MNI.bvals / Diffusion_MNI.bvecs
        5. create a brain mask for the registered DWI with Mrtrix.create_brain_mask

        :param input_file: path of the DWI
        :param bvals: path of the bvals file
        :param bvecs: path of the bvecs file
        :param brain_mask: path of the brain mask of the input DWI (passed to calc_FA)
        :param output_dir: directory for all intermediate and output files
        :return: (registered DWI path, bvals path, bvecs path, brain mask from Mrtrix.create_brain_mask)
        :raises subprocess.CalledProcessError: if calc_FA or flirt exit with a non-zero status
        '''
        import shutil
        import subprocess

        def run(cmd):
            # Argument list without a shell: paths containing spaces or shell metacharacters are
            # passed verbatim, and a failing tool raises instead of being silently ignored
            subprocess.run(cmd, check=True)

        print("Moving input to MNI space...")

        fa_file = join(output_dir, "FA.nii.gz")
        fa_to_mni = join(output_dir, "FA_2_MNI.mat")
        new_input_file = join(output_dir, "Diffusion_MNI.nii.gz")
        new_bvals = join(output_dir, "Diffusion_MNI.bvals")
        new_bvecs = join(output_dir, "Diffusion_MNI.bvecs")

        run(["calc_FA", "-i", input_file, "-o", fa_file, "--bvals", bvals, "--bvecs", bvecs,
             "--brain_mask", brain_mask])

        dwi_spacing = ImgUtils.get_image_spacing(input_file)

        template_path = resource_filename('examples.resources',
                                          'MNI_FA_template.nii.gz')

        run(["flirt", "-ref", template_path, "-in", fa_file, "-out", join(output_dir, "FA_MNI.nii.gz"),
             "-omat", fa_to_mni, "-dof", "6", "-cost", "mutualinfo", "-searchcost", "mutualinfo"])

        run(["flirt", "-ref", template_path, "-in", input_file, "-out", new_input_file,
             "-applyisoxfm", str(dwi_spacing), "-init", fa_to_mni, "-dof", "6"])

        shutil.copy(bvals, new_bvals)
        shutil.copy(bvecs, new_bvecs)

        brain_mask = Mrtrix.create_brain_mask(new_input_file, output_dir)

        return new_input_file, new_bvals, new_bvecs, brain_mask
Ejemplo n.º 6
0
    def cut_and_scale_img_back_to_original_img(data, t):
        '''
        Undo the transformations done with pad_and_scale_img_to_square_img

        :param data: 3D or 4D image (for 4D images only the first three dims are rescaled)
        :param t: transformation dict as returned by pad_and_scale_img_to_square_img
                  (reads "zoom", "pad_x", "pad_y", "pad_z")
        :return: image with the zoom undone and the padding cut away
        :raises ValueError: if data is neither 3D nor 4D (previously an UnboundLocalError)
        '''
        nr_dims = len(data.shape)
        if nr_dims not in (3, 4):
            raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

        # Back to old size
        # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
        if nr_dims == 3:
            new_data = ndimage.zoom(data, (1. / t["zoom"]), order=0)
        else:
            new_data = ImgUtils.resize_first_three_dims(data, order=0, zoom=(1. / t["zoom"]))

        # Cut padding. A padding of n.5 means one more voxel was padded at the end ("behind" the image)
        # -> cut one voxel more at the end.
        cut = []
        for axis, key in enumerate(("pad_x", "pad_y", "pad_z")):
            pad = t[key]
            start = int(pad)
            residual = 1 if pad - int(pad) == 0.5 else 0
            cut.append(slice(start, new_data.shape[axis] - start - residual))

        return new_data[tuple(cut)]
Ejemplo n.º 7
0
    def cut_and_scale_img_back_to_original_img(data, t):
        '''
        Undo the transformations done with pad_and_scale_img_to_square_img

        :param data: 3D or 4D image (for 4D images only the first three dims are rescaled)
        :param t: transformation dict returned by pad_and_scale_img_to_square_img
                  (uses "zoom", "pad_x", "pad_y", "pad_z")
        :return: image with the zoom undone and the padding removed
        :raises ValueError: if data is neither 3D nor 4D
        '''
        nr_dims = len(data.shape)
        # Explicit check instead of an assert: asserts are removed with "python -O", which would
        # turn a wrong input into an UnboundLocalError for new_data below
        if not 3 <= nr_dims <= 4:
            raise ValueError("data has to be a 3D or 4D image, but has {} dims".format(nr_dims))

        # Back to old size
        # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
        if nr_dims == 3:
            new_data = ndimage.zoom(data, (1. / t["zoom"]), order=0)
        else:
            new_data = ImgUtils.resize_first_three_dims(data, order=0, zoom=(1. / t["zoom"]))

        # check if has 0.5 residual -> one more voxel was padded at the end, so cut 1 pixel more at the end
        x_residual = 1 if t["pad_x"] - int(t["pad_x"]) == 0.5 else 0
        y_residual = 1 if t["pad_y"] - int(t["pad_y"]) == 0.5 else 0
        z_residual = 1 if t["pad_z"] - int(t["pad_z"]) == 0.5 else 0

        # Cut padding
        shape = new_data.shape
        return new_data[int(t["pad_x"]): shape[0] - int(t["pad_x"]) - x_residual,
                        int(t["pad_y"]): shape[1] - int(t["pad_y"]) - y_residual,
                        int(t["pad_z"]): shape[2] - int(t["pad_z"]) - z_residual]
Ejemplo n.º 8
0
    def save_fusion_nifti_as_npy():
        '''
        Convert the fused probability maps (nifti) of all subjects to .npy files.

        For each subject two files are written to
        C.NETWORK_DRIVE/HCP_fusion_npy_<DIFFUSION_FOLDER>/<subject>/:
          - <DIFFUSION_FOLDER>_xyz.npy: probmap (NaNs set to 0) scaled to the UNet shape
          - bundle_masks.npy: bundle masks from HCP_for_training_COPY scaled to the UNet shape
        '''

        # Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        # change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        fusion_dir = join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER)
        npy_root = join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER)
        labels_root = join(C.NETWORK_DRIVE, "HCP_for_training_COPY")

        print("\n\nProcessing Data...")
        for subject in subjects:
            subject_npy_dir = join(npy_root, subject)

            print("processing data subject {}".format(subject))
            t_start = time.time()
            probmap = nib.load(join(fusion_dir, subject + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            probmap = DatasetUtils.scale_input_to_unet_shape(np.nan_to_num(probmap), HP.DATASET, HP.RESOLUTION)
            # scale_input_to_world_shape outputted 146 voxels in x and z -> one too much at the end
            probmap = probmap[:-1, :, :-1, :]
            ExpUtils.make_dir(subject_npy_dir)
            np.save(join(subject_npy_dir, DIFFUSION_FOLDER + "_xyz.npy"), probmap)
            print("Took {}s".format(time.time() - t_start))

            print("processing seg subject {}".format(subject))
            t_start = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, subject, labels_type=HP.LABELS_TYPE)
            seg = nib.load(join(labels_root, subject, HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            np.save(join(subject_npy_dir, "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - t_start))
Ejemplo n.º 9
0
    def move_to_subject_space(output_dir):
        '''
        Move output_dir/bundle_segmentations.nii.gz from MNI space back to subject space.

        Inverts output_dir/FA_2_MNI.mat into output_dir/MNI_2_FA.mat, applies it with flirt
        (reference: output_dir/FA.nii.gz, resampled with the voxel spacing of the segmentation file)
        and binarizes the result at 0.5 with fslmaths.
        Result: output_dir/bundle_segmentations_subjectSpace.nii.gz

        :param output_dir: directory containing bundle_segmentations.nii.gz, FA.nii.gz and FA_2_MNI.mat
        :raises subprocess.CalledProcessError: if one of the FSL tools exits with a non-zero status
        '''
        import subprocess

        print("Moving input to subject space...")

        file_path_in = output_dir + "/bundle_segmentations.nii.gz"
        file_path_out = output_dir + "/bundle_segmentations_subjectSpace.nii.gz"
        fa_to_mni = output_dir + "/FA_2_MNI.mat"
        mni_to_fa = output_dir + "/MNI_2_FA.mat"
        dwi_spacing = ImgUtils.get_image_spacing(file_path_in)

        # Argument lists without a shell: paths with spaces or shell metacharacters are passed
        # verbatim, and a failing tool raises instead of being silently ignored (as with os.system)
        subprocess.run(["convert_xfm", "-omat", mni_to_fa, "-inverse", fa_to_mni], check=True)
        subprocess.run(["flirt", "-ref", output_dir + "/FA.nii.gz", "-in", file_path_in, "-out", file_path_out,
                        "-applyisoxfm", str(dwi_spacing), "-init", mni_to_fa, "-dof", "6"], check=True)
        subprocess.run(["fslmaths", file_path_out, "-thr", "0.5", "-bin", file_path_out], check=True)
Ejemplo n.º 10
0
    def create_one_3D_file():
        '''
        Create one big file which contains all 3D Images (not slices).

        For every subject returned by get_all_subjects() the peak image (NaNs set to 0) and the
        multilabel bundle mask are scaled to the UNet input shape. All images are saved stacked in
        "data.npy", all masks in "seg.npy" (both in the current working directory).
        '''
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            DATASET_FOLDER = "HCP"

        data_all = []
        seg_all = []

        print("\n\nProcessing Data...")
        for s in get_all_subjects():
            print("processing data subject {}".format(s))
            # np.asanyarray(img.dataobj) is nibabel's replacement for the deprecated get_data()
            data = np.asanyarray(nib.load(
                join(C.HOME, HP.DATASET_FOLDER, s,
                     HP.FEATURES_FILENAME + ".nii.gz")).dataobj)
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, HP.DATASET, HP.RESOLUTION)
            # Append inside the loop, otherwise only the last subject is saved
            data_all.append(np.array(data))
        np.save("data.npy", data_all)
        del data_all  # free memory

        print("\n\nProcessing Segs...")
        for s in get_all_subjects():
            print("processing seg subject {}".format(s))
            seg = ImgUtils.create_multilabel_mask(HP,
                                                  s,
                                                  labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, HP.DATASET, HP.RESOLUTION)
            seg_all.append(np.array(seg))
        # A plain list has no .dtype -> convert before printing/saving
        seg_all = np.array(seg_all)
        print("SEG TYPE: {}".format(seg_all.dtype))
        np.save("seg.npy", seg_all)
Ejemplo n.º 11
0
    def scale_input_to_world_shape(img4d, dataset, resolution="1.25mm"):
        '''
        Pad/cut an image in UNet shape back to the original image size.
        Note: no resampling is done here, only padding and cutting.

        :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
        :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED"
                        (which of them are supported depends on the resolution)
        :param resolution: "1.25mm" / "2mm" / "2.5mm"
        :return: img with original size: 1.25mm: (146,174,146,none), 2mm: (90,108,90,none),
                 2.5mm: (73,87,73,none) or for TRACED (78,93,75,none)
        :raises ValueError: if the combination of dataset and resolution is not supported
        '''

        if resolution == "1.25mm":
            if dataset in ("HCP", "HCP_32g"):  # (144,144,144)
                # no resize needed
                return ImgUtils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                                  [146, 174, 146, img4d.shape[3]], pad_value=0)  # (146, 174, 146, none)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")

        elif resolution == "2mm":
            if dataset in ("HCP", "HCP_32g", "HCP_2mm"):  # (80,80,80)
                return ImgUtils.pad_4d_image_left(img4d, np.array([5, 14, 5, 0]),
                                                  [90, 108, 90, img4d.shape[3]], pad_value=0)  # (90, 108, 90, none)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")

        elif resolution == "2.5mm":
            if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):  # (80,80,80)
                img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 4, 0, 0]),
                                                   [80, 87, 80, img4d.shape[3]], pad_value=0)  # (80,87,80,none)
                return img4d[4:77, :, 4:77, :]  # (73, 87, 73, none)
            if dataset == "TRACED":  # (80,80,80)
                img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 7, 0, 0]),
                                                   [80, 93, 80, img4d.shape[3]], pad_value=0)  # (80,93,80,none)
                return img4d[1:79, :, 3:78, :]  # (78,93,75,none)

        # Previously an unsupported combination silently returned None
        raise ValueError("dataset '{}' not supported for resolution '{}'".format(dataset, resolution))
Ejemplo n.º 12
0
    def pad_and_scale_img_to_square_img(data, target_size=144):
        '''
        Expects 3D or 4D image as input.

        Does
        1. Pad image with 0 to make it square (only the first three = spatial dims are made square,
            the last dim of a 4D image is kept as is)
            (if uneven padding -> adds one more px "behind" img; but resulting img shape will be correct)
        2. Scale image to UNet size (target_size, target_size, target_size)

        :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
        :param target_size: edge length of the resulting cube (default 144)
        :return: (new_img, transformation) - transformation is the dict that
                 cut_and_scale_img_back_to_original_img needs to undo this step
        :raises ValueError: if data is neither 3D nor 4D
        '''
        nr_dims = len(data.shape)
        # Explicit check instead of an assert (asserts are stripped with "python -O")
        if nr_dims not in (3, 4):
            raise ValueError("Expected a 3D or 4D image, got {} dimensions".format(nr_dims))

        shape = data.shape
        # Only the spatial dims define the cube. Using max(shape) would include the channel dim of a
        # 4D image and e.g. pad a (73,87,73,288) DWI to a 288^3 cube.
        biggest_dim = max(shape[:3])

        # Pad to make square (shape[3:] is () for 3D images and (channels,) for 4D images)
        new_img = np.zeros((biggest_dim, biggest_dim, biggest_dim) + tuple(shape[3:]), dtype=data.dtype)
        pad1 = (biggest_dim - shape[0]) / 2.
        pad2 = (biggest_dim - shape[1]) / 2.
        pad3 = (biggest_dim - shape[2]) / 2.
        new_img[int(pad1):int(pad1) + shape[0],
                int(pad2):int(pad2) + shape[1],
                int(pad3):int(pad3) + shape[2]] = data

        # Scale to right size
        zoom = float(target_size) / biggest_dim
        if nr_dims == 4:
            # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
            new_img = ImgUtils.resize_first_three_dims(new_img,
                                                       order=0,
                                                       zoom=zoom)
        else:
            new_img = ndimage.zoom(new_img, zoom, order=0)

        transformation = {
            "original_shape": shape,
            "pad_x": pad1,
            "pad_y": pad2,
            "pad_z": pad3,
            "zoom": zoom
        }

        return new_img, transformation
Ejemplo n.º 13
0
    def save_fusion_nifti_as_npy():
        '''
        Store the fused probability maps and bundle masks of all subjects as .npy files.

        Output per subject in C.NETWORK_DRIVE/HCP_fusion_npy_<DIFFUSION_FOLDER>/<subject>/:
          - <DIFFUSION_FOLDER>_xyz.npy: fused probmap (NaNs set to 0) in UNet shape
          - bundle_masks.npy: bundle masks (from HCP_for_training_COPY) in UNet shape
        '''

        # Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        # change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        src_dir = join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER)
        dst_dir = join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER)
        seg_src_dir = join(C.NETWORK_DRIVE, "HCP_for_training_COPY")

        print("\n\nProcessing Data...")
        for subj in subjects:
            subj_dst = join(dst_dir, subj)

            # --- probmap ---
            print("processing data subject {}".format(subj))
            tic = time.time()
            fused = nib.load(join(src_dir, subj + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            fused = np.nan_to_num(fused)
            fused = DatasetUtils.scale_input_to_unet_shape(fused, HP.DATASET, HP.RESOLUTION)
            # scale_input_to_world_shape outputted 146 voxels in x and z -> cut the extra one at the end
            fused = fused[:-1, :, :-1, :]
            ExpUtils.make_dir(subj_dst)
            np.save(join(subj_dst, DIFFUSION_FOLDER + "_xyz.npy"), fused)
            print("Took {}s".format(time.time() - tic))

            # --- segmentation ---
            print("processing seg subject {}".format(subj))
            tic = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, subj, labels_type=HP.LABELS_TYPE)
            seg = nib.load(join(seg_src_dir, subj, HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            np.save(join(subj_dst, "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - tic))
Ejemplo n.º 14
0
    def pad_and_scale_img_to_square_img(data, target_size=144):
        '''
        Expects 3D or 4D image as input.

        Does
        1. Pad image with 0 to make it square (only the three spatial dims; the last dim of a
            4D image is kept as is)
            (if uneven padding -> adds one more px "behind" img; but resulting img shape will be correct)
        2. Scale image to UNet size (target_size, target_size, target_size)

        :param data: 3D (x, y, z) or 4D (x, y, z, channels) image
        :param target_size: edge length of the resulting cube (default 144)
        :return: (new_img, transformation); transformation holds original_shape, pad_x/y/z and zoom
                 and can be passed to cut_and_scale_img_back_to_original_img to undo this step
        :raises ValueError: if data is neither 3D nor 4D
        '''
        nr_dims = len(data.shape)
        if not 3 <= nr_dims <= 4:
            raise ValueError("data has to be a 3D or 4D image, but has {} dims".format(nr_dims))

        shape = data.shape
        # Use only the spatial dims: including the channel dim of a 4D image would pad to a much
        # larger cube whenever there are more channels than voxels along the longest axis
        biggest_dim = max(shape[:3])

        # Pad to make square
        if nr_dims == 4:
            new_img = np.zeros((biggest_dim, biggest_dim, biggest_dim, shape[3]), dtype=data.dtype)
        else:
            new_img = np.zeros((biggest_dim, biggest_dim, biggest_dim), dtype=data.dtype)
        pad1 = (biggest_dim - shape[0]) / 2.
        pad2 = (biggest_dim - shape[1]) / 2.
        pad3 = (biggest_dim - shape[2]) / 2.
        new_img[int(pad1):int(pad1) + shape[0],
                int(pad2):int(pad2) + shape[1],
                int(pad3):int(pad3) + shape[2]] = data

        # Scale to right size
        zoom = float(target_size) / biggest_dim
        if nr_dims == 4:
            # use order=0, otherwise image values of a DWI will be quite different after downsampling and upsampling
            new_img = ImgUtils.resize_first_three_dims(new_img, order=0, zoom=zoom)
        else:
            new_img = ndimage.zoom(new_img, zoom, order=0)

        transformation = {
            "original_shape": shape,
            "pad_x": pad1,
            "pad_y": pad2,
            "pad_z": pad3,
            "zoom": zoom
        }

        return new_img, transformation
Ejemplo n.º 15
0
    def scale_input_to_unet_shape(img4d, dataset, resolution="1.25mm"):
        '''
        Scale input image to right isotropic resolution and pad/cut image to make it square to fit UNet input shape

        :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
        :param dataset: "HCP" / "HCP_32g" / "HCP_2mm" / "HCP_2.5mm" / "TRACED" / "Schizo"
                        (which of them are supported depends on the resolution)
        :param resolution: "1.25mm" / "2mm" / "2.5mm"     results in UNet input shape of (144,144,144) or (80,80,80)
        :return: img with dim 1.25mm: (144,144,144,none) or 2mm: (80,80,80,none) or 2.5mm: (80,80,80,none)
                    (note: for 2.5mm the image is padded to 80,80,80 with the value of its first voxel;
                    exception: "Schizo" at 2mm returns (91,91,91,none))
        :raises ValueError: if the combination of dataset and resolution is not supported
        '''

        def paste_into_bg(img, bg_region, img_region):
            # (80,80,80,none) volume filled with the value of the first voxel of img, so the padding has
            # the same value as the background of the original image (broadcast over the last dim)
            bg = np.zeros((80, 80, 80, img.shape[3])).astype(img.dtype)
            bg = bg + img[0, 0, 0, :]
            bg[bg_region] = img[img_region]
            return bg  # (80,80,80,none)

        if resolution == "1.25mm":
            if dataset == "HCP":  # (145,174,145)
                # no resize needed
                return img4d[1:, 15:159, 1:]  # (144,144,144)
            if dataset == "HCP_32g":  # (73,87,73)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=2)  # (146,174,146,none)
                img4d = img4d[:-1, :, :-1]  # remove one voxel that came from upsampling  (145,174,145)
                return img4d[1:, 15:159, 1:]  # (144,144,144)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")
            if dataset == "Schizo":  # (91,109,91)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=1.60)  # (146,174,146)
                return img4d[1:145, 15:159, 1:145]  # (144,144,144)

        elif resolution == "2mm":
            if dataset in ("HCP", "HCP_32g"):  # (145,174,145)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.62)  # (90,108,90)
                return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
            if dataset == "HCP_2mm":  # (90,108,90)
                # no resize needed
                return img4d[5:85, 14:94, 5:85, :]  # (80,80,80)
            if dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")
            if dataset == "Schizo":  # (91,109,91)
                # NOTE: only y is cut; the result is (91,91,91), not (80,80,80) like the other 2mm datasets
                return img4d[:, 9:100, :]  # (91,91,91)

        elif resolution == "2.5mm":
            if dataset == "HCP":  # (145,174,145)
                img4d = ImgUtils.resize_first_three_dims(img4d, zoom=0.5)  # (73,87,73,none)
                return paste_into_bg(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])
            if dataset in ("HCP_2.5mm", "HCP_32g"):  # (73,87,73,none), no resize needed
                return paste_into_bg(img4d, np.s_[4:77, :, 4:77], np.s_[:, 4:84, :, :])
            if dataset == "TRACED":  # (78,93,75), no resize needed
                return paste_into_bg(img4d, np.s_[1:79, :, 3:78, :], np.s_[:, 7:87, :, :])

        # Previously an unsupported combination silently returned None
        raise ValueError("dataset '{}' not supported for resolution '{}'".format(dataset, resolution))
Ejemplo n.º 16
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg on a single subject.

    :param data: input image as numpy array. For input_type "peaks": peaks with shape [x,y,z,9].
                 For input_type "T1": a 3D image (a channel axis of size 1 is added internally).
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: predict with a single orientation instead of fusing the predictions of
                               three orientations (mainly needed for testing because of less RAM requirements)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    :raises ValueError: if the experiment type of the selected pretrained config is not supported
    '''
    start_time = time.time()

    # Experiment types whose prediction is produced slice-wise and (optionally) fused over 3 orientations
    segmentation_types = ("tract_segmentation", "endings_segmentation", "dm_regression")

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Fail early: any other experiment type would never assign `seg` and only crash at the very end
    # (after downloading weights and building the model).
    if HP.EXPERIMENT_TYPE not in segmentation_types and HP.EXPERIMENT_TYPE != "peak_regression":
        raise ValueError("Unsupported experiment type '{}' (config '{}')".format(HP.EXPERIMENT_TYPE, config))

    # Bundle specific thresholding is applied afterwards on the probabilities
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # NOTE(review): points to a network drive instead of TRACT_SEG_HOME - confirm this is intended
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # [1:]: skip first entry (no bundle)
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3 * len(bundle_names)  # one 3D peak per bundle
    else:
        HP.NR_OF_CLASSES = len(bundle_names)

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, _, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in segmentation_types:
        # Regression outputs, dropout uncertainty and explicitly requested probabilities stay continuous
        use_probs = bool(HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS)
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            seg, _ = trainerSingle.get_seg_single_img(HP, probs=use_probs, scale_to_world_shape=False,
                                                      only_prediction=True)
        else:
            seg_xyz, _ = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data,
                                                                        scale_to_world_shape=False,
                                                                        only_prediction=True)
            seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=use_probs)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # Fusing 3 orientations does not work for peaks (?) -> single orientation only
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, _ = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, bundle_names, len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, bundle_names)

    # remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
Ejemplo n.º 17
0
    def scale_input_to_world_shape(img4d, dataset, resolution="1.25mm"):
        '''
        Scale input image to original resolution and pad/cut image to make it original size.

        Counterpart of `scale_input_to_unet_shape`: undoes its cropping/padding (and resizing where applicable).

        :param img4d: (x, y, z, userdefined)  (userdefined could be gradients or classes)
        :param dataset: "HCP" | "HCP_32g" | "HCP_2mm" | "HCP_2.5mm" | "TRACED" | "Schizo"
        :param resolution: "1.25mm" / "2mm" / "2.5mm"
        :return: img with original size
        :raises ValueError: if the dataset/resolution combination is not supported
                            (previously such combinations silently returned None)
        '''
        if resolution == "1.25mm":
            if dataset in ("HCP", "HCP_32g"):  # (144,144,144)
                # no resize needed
                return ImgUtils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                                  [146, 174, 146, img4d.shape[3]], pad_value=0)  # (146, 174, 146, none)
            elif dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '1.25mm' not supported for dataset 'TRACED'")
            elif dataset == "Schizo":  # (144,144,144)
                img4d = ImgUtils.pad_4d_image_left(img4d, np.array([1, 15, 1, 0]),
                                                   [145, 174, 145, img4d.shape[3]], pad_value=0)  # (145, 174, 145, none)
                return ImgUtils.resize_first_three_dims(img4d, zoom=0.62)  # (91,109,91)

        elif resolution == "2mm":
            if dataset in ("HCP", "HCP_32g", "HCP_2mm"):  # (80,80,80)
                return ImgUtils.pad_4d_image_left(img4d, np.array([5, 14, 5, 0]),
                                                  [90, 108, 90, img4d.shape[3]], pad_value=0)  # (90, 108, 90, none)
            elif dataset == "TRACED":  # (78,93,75)
                raise ValueError("resolution '2mm' not supported for dataset 'TRACED'")

        elif resolution == "2.5mm":
            if dataset in ("HCP", "HCP_2.5mm", "HCP_32g"):  # (80,80,80)
                # undo the y-crop (4:84) of the forward transform, then remove the x/z zero padding
                img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 4, 0, 0]),
                                                   [80, 87, 80, img4d.shape[3]], pad_value=0)  # (80,87,80,none)
                return img4d[4:77, :, 4:77, :]  # (73, 87, 73, none)
            elif dataset == "TRACED":  # (80,80,80)
                # undo the y-crop (7:87) of the forward transform, then remove the x/z zero padding
                img4d = ImgUtils.pad_4d_image_left(img4d, np.array([0, 7, 0, 0]),
                                                   [80, 93, 80, img4d.shape[3]], pad_value=0)  # (80,93,80,none)
                return img4d[1:79, :, 3:78, :]  # (78,93,75,none)

        # Every supported combination returned above; do not silently hand back None
        raise ValueError("resolution '{}' not supported for dataset '{}'".format(resolution, dataset))
Ejemplo n.º 18
0
    def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):
        '''
        Create a slice file from the combined x/y/z UNet probability maps of the given subjects.

        For each subject the probmaps of the three models "UNet_x/y/z_<CV_FOLD>" are stacked along a new
        last axis and merged into the channel axis. The result is saved as NIfTI image to
        "<EXP_PATH>/combined/". Then it is cut into z-slices together with the matching multilabel mask.

        :param HP: hyperparameter object (reads MULTI_PARENT_PATH, CV_FOLD, LABELS_TYPE, DATASET,
                   RESOLUTION, EXP_PATH, TRAIN)
        :param subjects: subject ids to process
        :param filename: output prefix; "<filename>_data.npy" and "<filename>_seg.npy" are written
        :param bundle: not used (kept for interface compatibility)
        :param shuffle: shuffle slices (deterministically, random_state=9) before saving
        :return: None; the .npy files are only written if HP.TRAIN is set
        :raises ValueError: if probmaps and mask do not have the same number of z-slices
        '''
        input_dir = HP.MULTI_PARENT_PATH
        affine = ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)

        # For the downscaled datasets the probmaps already have the low resolution but the mask is created
        # at original resolution -> use "HCP" so that the mask gets downscaled as well.
        # Otherwise the mask has the same resolution as the probmaps -> use the same resizing.
        if HP.DATASET in ("HCP_2mm", "HCP_2.5mm"):
            mask_dataset = "HCP"
        else:
            mask_dataset = HP.DATASET

        combined_slices = []
        mask_slices = []

        for s in subjects:
            print("processing subject {}".format(s))

            probmaps = [nib.load(join(input_dir, "UNet_" + direction + "_" + str(HP.CV_FOLD), "probmaps",
                                      s + "_probmap.nii.gz")).get_data()
                        for direction in ("x", "y", "z")]
            # (73, 87, 73, 18, 3): one dim too much for UNet -> merge the last two axes
            combined = np.stack(probmaps, axis=4)
            combined = np.reshape(
                combined,
                (combined.shape[0], combined.shape[1], combined.shape[2],
                 combined.shape[3] * combined.shape[4]))  # (73, 87, 73, 3*18)

            mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, mask_dataset, HP.RESOLUTION)

            # Save as Img (file name spelling "combinded" kept for compatibility with existing outputs)
            img = nib.Nifti1Image(combined, affine)
            nib.save(img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))

            combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
            if combined.shape[2] != mask_data.shape[2]:
                raise ValueError("Subject {}: probmaps have {} z-slices but mask has {}".format(
                    s, combined.shape[2], mask_data.shape[2]))

            # Save as Slices
            for z in range(combined.shape[2]):
                combined_slices.append(combined[:, :, z, :])
                mask_slices.append(mask_data[:, :, z, :])

        if shuffle:
            combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9)

        if HP.TRAIN:
            np.save(filename + "_data.npy", combined_slices)
            np.save(filename + "_seg.npy", mask_slices)
Ejemplo n.º 19
0
    def track(bundle, peaks, output_dir, brain_mask, filter_by_endpoints=False):
        '''
        Track one bundle with MRtrix `tckgen` (FACT on the bundle's TOM) and save the result as .trk.

        :param bundle:   Bundle name
        :param peaks:    not used by the active code path (only by the commented-out iFOD2 variant)
        :param output_dir: TractSeg output directory; must contain "TOM/<bundle>.nii.gz" and, when
                           filtering, "bundle_segmentations/" and "endings_segmentations/"
        :param brain_mask: path to brain mask; used as seed image and as reference affine for the .trk
        :param filter_by_endpoints:     use results of endings_segmentation to filter out all fibers not ending
                                        in those regions (falls back to unfiltered tracking if the beginnings
                                        or endings mask is empty)
        :return: None; writes "<output_dir>/TOM_trackings/<bundle>.trk"
        '''
        import subprocess

        tracking_folder = "TOM_trackings"
        # tracking_folder = "TOM_trackings_FiltEP6_FiltMask3"
        smooth = None       # None / 10
        TOM_folder = "TOM"      # TOM / TOM_thr1

        tracking_dir = os.path.join(output_dir, tracking_folder)
        tck_file = os.path.join(tracking_dir, bundle + ".tck")

        # `os.system("export PATH=...")` only changes a throw-away subshell, so pass the extended PATH
        # explicitly to the MRtrix call instead.
        env = dict(os.environ)
        env["PATH"] = "/code/mrtrix3/bin" + os.pathsep + env.get("PATH", "")

        if not os.path.isdir(tracking_dir):
            os.makedirs(tracking_dir)

        use_endpoint_filter = False
        if filter_by_endpoints:
            endings_dir = os.path.join(output_dir, "endings_segmentations")
            beginnings_mask_ok = nib.load(os.path.join(endings_dir, bundle + "_b.nii.gz")).get_data().max() > 0
            endings_mask_ok = nib.load(os.path.join(endings_dir, bundle + "_e.nii.gz")).get_data().max() > 0

            if not beginnings_mask_ok:
                print("WARNING: tract beginnings mask of {} empty".format(bundle))

            if not endings_mask_ok:
                print("WARNING: tract endings mask of {} empty".format(bundle))

            use_endpoint_filter = beginnings_mask_ok and endings_mask_ok

        tmp_dir = tempfile.mkdtemp()
        try:
            # Argument list instead of a shell string: paths with spaces/special characters stay intact
            cmd = ["tckgen", "-algorithm", "FACT",
                   os.path.join(output_dir, TOM_folder, bundle + ".nii.gz"),
                   tck_file,
                   "-seed_image", brain_mask]

            if use_endpoint_filter:
                tmp_bundle_mask = os.path.join(tmp_dir, bundle + ".nii.gz")
                tmp_beginnings_mask = os.path.join(tmp_dir, bundle + "_b.nii.gz")
                tmp_endings_mask = os.path.join(tmp_dir, bundle + "_e.nii.gz")
                # dilation has to be quite high, because endings sometimes almost completely missing
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "bundle_segmentations", bundle + ".nii.gz"),
                                            tmp_bundle_mask, dilation=3)
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "endings_segmentations", bundle + "_e.nii.gz"),
                                            tmp_endings_mask, dilation=6)
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "endings_segmentations", bundle + "_b.nii.gz"),
                                            tmp_beginnings_mask, dilation=6)
                cmd += ["-mask", tmp_bundle_mask,
                        "-include", tmp_beginnings_mask,
                        "-include", tmp_endings_mask]

            cmd += ["-minlength", "40", "-select", "2000", "-force", "-quiet"]
            # Like the former os.system call a failing tckgen does not abort here; it is only reported.
            return_code = subprocess.call(cmd, env=env)
            if return_code != 0:
                print("WARNING: tckgen failed for {} (exit code {})".format(bundle, return_code))

            # `.affine` instead of `get_affine()`, which is removed in recent nibabel versions
            reference_affine = nib.load(brain_mask).affine
            FiberUtils.convert_tck_to_trk(tck_file, os.path.join(tracking_dir, bundle + ".trk"),
                                          reference_affine, smooth=smooth)
            if os.path.exists(tck_file):
                os.remove(tck_file)
        finally:
            shutil.rmtree(tmp_dir)
Ejemplo n.º 20
0
    def precompute_batches(custom_type=None):
        '''
        Precompute batches for train/validate/test and save each batch as NIfTI files.

        9000 slices per epoch -> 200 batches (batchsize=44) per epoch
        => 200-1000 batches needed

        Dataset variants (select via DATASET_DIR below):
        270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
        All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
        270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

        :param custom_type: "train" | "validate" | "test" to only generate that split; None generates all three
        :return: None; writes "<C.HOME>/<DATASET_DIR>/<split>/batch_<idx>_data.nii.gz" and "..._seg.nii.gz"
        :raises ValueError: if custom_type is not one of the supported splits
        '''
        class HP:
            NORMALIZE_DATA = True
            DATA_AUGMENTATION = False
            CV_FOLD = 0
            INPUT_DIM = (144, 144)
            BATCH_SIZE = 44
            DATASET_FOLDER = "HCP"
            TYPE = "single_direction"
            EXP_PATH = "~"
            LABELS_FILENAME = "bundle_peaks"
            FEATURES_FILENAME = "270g_125mm_peaks"
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            LABELS_TYPE = np.float32

        HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

        # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
        # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
        DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"

        num_batches_base = 5000
        num_batches = {
            "train": num_batches_base,
            "validate": int(num_batches_base / 3.),
            "test": int(num_batches_base / 3.),
        }

        if custom_type is None:
            splits = ["train", "validate", "test"]
        elif custom_type in num_batches:
            splits = [custom_type]
        else:
            raise ValueError("custom_type must be one of {} or None, got {!r}".format(
                sorted(num_batches), custom_type))

        # Same affine for every batch -> compute once
        affine = ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION)

        for split in splits:  # `split` instead of `type` to not shadow the builtin
            out_dir = join(C.HOME, DATASET_DIR, split)
            ExpUtils.make_dir(out_dir)

            dataManager = DataManagerTrainingNiftiImgs(HP)
            batch_gen = dataManager.get_batches(
                batch_size=HP.BATCH_SIZE,
                type=split,
                subjects=getattr(HP, split.upper() + "_SUBJECTS"),
                num_batches=num_batches[split])

            for idx, batch in enumerate(batch_gen):
                print("Processing: {}".format(idx))

                data = nib.Nifti1Image(batch["data"], affine)
                nib.save(data, join(out_dir, "batch_" + str(idx) + "_data.nii.gz"))

                seg = nib.Nifti1Image(batch["seg"], affine)
                nib.save(seg, join(out_dir, "batch_" + str(idx) + "_seg.nii.gz"))
Ejemplo n.º 21
0
    def track(bundle,
              peaks,
              output_dir,
              brain_mask,
              filter_by_endpoints=False,
              output_format="trk"):
        '''
        Track one bundle with MRtrix `tckgen` (FACT on the bundle's TOM).

        :param bundle:   Bundle name
        :param peaks:    not used by the active code path (only by the commented-out iFOD2 variant)
        :param output_dir: TractSeg output directory; must contain "TOM/<bundle>.nii.gz" and, when
                           filtering, "bundle_segmentations/" and "endings_segmentations/"
        :param brain_mask: path to brain mask; used as seed image and as reference affine for the .trk
        :param filter_by_endpoints:     use results of endings_segmentation to filter out all fibers not endings
                                        in those regions (falls back to unfiltered tracking if the bundle,
                                        beginnings or endings mask is empty)
        :param output_format: "trk" converts the result to "<bundle>.trk" and deletes the .tck;
                              any other value keeps "<bundle>.tck"
        :return: None; writes into "<output_dir>/TOM_trackings/"
        '''
        import subprocess

        tracking_folder = "TOM_trackings"
        # tracking_folder = "TOM_trackings_FiltEP6_FiltMask3"
        smooth = None  # None / 10
        TOM_folder = "TOM"  # TOM / TOM_thr1

        tracking_dir = os.path.join(output_dir, tracking_folder)
        tck_file = os.path.join(tracking_dir, bundle + ".tck")

        # `os.system("export PATH=...")` only changes a throw-away subshell, so pass the extended PATH
        # explicitly to the MRtrix call instead.
        env = dict(os.environ)
        env["PATH"] = "/code/mrtrix3/bin" + os.pathsep + env.get("PATH", "")

        if not os.path.isdir(tracking_dir):
            os.makedirs(tracking_dir)

        use_endpoint_filter = False
        if filter_by_endpoints:
            endings_dir = os.path.join(output_dir, "endings_segmentations")
            bundle_mask_ok = nib.load(os.path.join(output_dir, "bundle_segmentations",
                                                   bundle + ".nii.gz")).get_data().max() > 0
            beginnings_mask_ok = nib.load(os.path.join(endings_dir, bundle + "_b.nii.gz")).get_data().max() > 0
            endings_mask_ok = nib.load(os.path.join(endings_dir, bundle + "_e.nii.gz")).get_data().max() > 0

            if not bundle_mask_ok:
                print(
                    "WARNING: tract mask of {} empty. Falling back to tracking without filtering by endpoints."
                    .format(bundle))

            if not beginnings_mask_ok:
                print(
                    "WARNING: tract beginnings mask of {} empty. Falling back to tracking without filtering by endpoints."
                    .format(bundle))

            if not endings_mask_ok:
                print(
                    "WARNING: tract endings mask of {} empty. Falling back to tracking without filtering by endpoints."
                    .format(bundle))

            use_endpoint_filter = bundle_mask_ok and beginnings_mask_ok and endings_mask_ok

        tmp_dir = tempfile.mkdtemp()
        try:
            # Argument list instead of a shell string: paths with spaces/special characters stay intact
            cmd = ["tckgen", "-algorithm", "FACT",
                   os.path.join(output_dir, TOM_folder, bundle + ".nii.gz"),
                   tck_file,
                   "-seed_image", brain_mask]

            if use_endpoint_filter:
                tmp_bundle_mask = os.path.join(tmp_dir, bundle + ".nii.gz")
                tmp_beginnings_mask = os.path.join(tmp_dir, bundle + "_b.nii.gz")
                tmp_endings_mask = os.path.join(tmp_dir, bundle + "_e.nii.gz")
                # dilation has to be quite high, because endings sometimes almost completely missing
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "bundle_segmentations", bundle + ".nii.gz"),
                                            tmp_bundle_mask, dilation=3)
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "endings_segmentations", bundle + "_e.nii.gz"),
                                            tmp_endings_mask, dilation=6)
                ImgUtils.dilate_binary_mask(os.path.join(output_dir, "endings_segmentations", bundle + "_b.nii.gz"),
                                            tmp_beginnings_mask, dilation=6)
                cmd += ["-mask", tmp_bundle_mask,
                        "-include", tmp_beginnings_mask,
                        "-include", tmp_endings_mask]

            cmd += ["-minlength", "40", "-select", "2000", "-force", "-quiet"]
            # Like the former os.system call a failing tckgen does not abort here; it is only reported.
            return_code = subprocess.call(cmd, env=env)
            if return_code != 0:
                print("WARNING: tckgen failed for {} (exit code {})".format(bundle, return_code))

            if output_format == "trk":
                # `.affine` instead of `get_affine()`, which is removed in recent nibabel versions
                reference_affine = nib.load(brain_mask).affine
                FiberUtils.convert_tck_to_trk(tck_file,
                                              os.path.join(tracking_dir, bundle + ".trk"),
                                              reference_affine,
                                              smooth=smooth)
                if os.path.exists(tck_file):
                    os.remove(tck_file)
        finally:
            shutil.rmtree(tmp_dir)
Ejemplo n.º 22
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9]); for input_type "T1" a 3D
                 array [x,y,z] (it is reshaped to [x,y,z,1] internally)
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks" | "T1"
    :param single_orientation: use a single-orientation prediction instead of merging the
                               predictions of all three directions (lower RAM requirements,
                               mainly used for testing). Ignored for peak regression, which
                               always uses a single orientation.
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    :raises ValueError: if the loaded config has an experiment type this function cannot run
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        # Bundle-specific binarization is applied further down on the probability map,
        # so the network has to return probabilities here.
        HP.GET_PROBS = True

    # Select the pretrained weights file matching input type and experiment type.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # NOTE(review): points to a network-drive path rather than C.TRACT_SEG_HOME
            # like every other branch -- confirm this is intended for release.
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # The first entry returned by get_bundle_names is skipped (not a bundle output channel).
    bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    if HP.EXPERIMENT_TYPE == "peak_regression":
        # One 3D peak vector per bundle.
        HP.NR_OF_CLASSES = 3 * len(bundle_names)
    else:
        HP.NR_OF_CLASSES = len(bundle_names)

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop away the zero border, then pad/scale to the square network input size.
    # Both steps are undone at the end of this function.
    data, _, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation", "dm_regression"):
        # Keep the un-thresholded output for dropout sampling, regression, or when
        # probabilities were explicitly requested; otherwise binarize with HP.THRESHOLD.
        use_probs = bool(HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS)
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            data_manager = DataManagerSingleSubjectByFile(HP, data=data)
            trainer = Trainer(model, data_manager)
            seg, _ = trainer.get_seg_single_img(HP, probs=use_probs, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, _ = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=use_probs)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        # Peak regression always runs along a single orientation (merging three
        # directions was found not to work for peaks).
        data_manager = DataManagerSingleSubjectByFile(HP, data=data)
        trainer = Trainer(model, data_manager)
        seg, _ = trainer.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, bundle_names, len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)

    else:
        # Previously this fell through and failed later with an UnboundLocalError on `seg`.
        raise ValueError("Unsupported experiment type: {}".format(HP.EXPERIMENT_TYPE))

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, bundle_names)

    # Undo the square scaling and the cropping (remove these two lines to keep super resolution).
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
Ejemplo n.º 23
0
    def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
        '''
        Create two .npy files containing 2D slices of all given subjects.

        Each subject's feature image and multilabel mask are scaled to the UNet input
        shape and cut into slices along all three axes (axis 0, then 1, then 2). Data
        and segmentation slices are produced in the same order (and, if shuffled, with
        the same permutation), so index i of "<filename>_data.npy" corresponds to
        index i of "<filename>_seg.npy".

        :param HP: hyperparameter object (reads DATASET_FOLDER, FEATURES_FILENAME,
                   DATASET, RESOLUTION, LABELS_TYPE)
        :param subjects: list of subject ids (sub-folders of C.HOME/HP.DATASET_FOLDER)
        :param filename: output path prefix; "_data.npy" and "_seg.npy" are appended
        :param slice: currently ignored -- slices from all three axes are always used
        :param shuffle: if True, shuffle data and seg slices with the same permutation
        '''
        data_dir = join(C.HOME, HP.DATASET_FOLDER)

        def _slices_all_axes(vol):
            # 2D slices along axis 0, then 1, then 2 (same order for data and seg).
            return ([vol[i, :, :, :] for i in range(vol.shape[0])] +
                    [vol[:, i, :, :] for i in range(vol.shape[1])] +
                    [vol[:, :, i, :] for i in range(vol.shape[2])])

        dwi_slices = []
        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing dwi subject {}".format(s))
            dwi_data = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz")).get_data()
            dwi_data = np.nan_to_num(dwi_data)
            dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)
            dwi_slices.extend(_slices_all_axes(dwi_data))

        dwi_slices = np.array(dwi_slices)
        random_idxs = None
        if shuffle:
            # Use a permutation, not np.random.choice(n, n): the latter samples with
            # replacement, which would duplicate some slices and drop others.
            random_idxs = np.random.permutation(len(dwi_slices))
            dwi_slices = dwi_slices[random_idxs]

        np.save(filename + "_data.npy", dwi_slices)
        del dwi_slices  # free memory

        mask_slices = []
        print("\n\nProcessing Segs...")
        for s in subjects:
            print("processing seg subject {}".format(s))
            mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)
            # Slice by the mask's own shape rather than the shape of whichever dwi
            # volume happened to be loaded last.
            mask_slices.extend(_slices_all_axes(mask_data))

        mask_slices = np.array(mask_slices)
        print("SEG TYPE: {}".format(mask_slices.dtype))
        if shuffle:
            # Same permutation as the data so slices stay aligned.
            mask_slices = mask_slices[random_idxs]

        np.save(filename + "_seg.npy", mask_slices)
Ejemplo n.º 24
0
    def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
        '''
        Create two .npy files containing 2D slices of all given subjects.

        Each subject's feature image and multilabel mask are scaled to the UNet input
        shape and cut into slices along all three axes (axis 0, then 1, then 2). Data
        and segmentation slices are produced in the same order (and, if shuffled, with
        the same permutation), so index i of "<filename>_data.npy" corresponds to
        index i of "<filename>_seg.npy".

        :param HP: hyperparameter object (reads DATASET_FOLDER, FEATURES_FILENAME,
                   DATASET, RESOLUTION, LABELS_TYPE)
        :param subjects: list of subject ids (sub-folders of C.HOME/HP.DATASET_FOLDER)
        :param filename: output path prefix; "_data.npy" and "_seg.npy" are appended
        :param slice: currently ignored -- slices from all three axes are always used
        :param shuffle: if True, shuffle data and seg slices with the same permutation
        '''
        data_dir = join(C.HOME, HP.DATASET_FOLDER)

        def _slices_all_axes(vol):
            # 2D slices along axis 0, then 1, then 2 (same order for data and seg).
            return ([vol[i, :, :, :] for i in range(vol.shape[0])] +
                    [vol[:, i, :, :] for i in range(vol.shape[1])] +
                    [vol[:, :, i, :] for i in range(vol.shape[2])])

        dwi_slices = []
        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing dwi subject {}".format(s))
            dwi_data = nib.load(
                join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz")).get_data()
            dwi_data = np.nan_to_num(dwi_data)
            dwi_data = DatasetUtils.scale_input_to_unet_shape(
                dwi_data, HP.DATASET, HP.RESOLUTION)
            dwi_slices.extend(_slices_all_axes(dwi_data))

        dwi_slices = np.array(dwi_slices)
        random_idxs = None
        if shuffle:
            # Use a permutation, not np.random.choice(n, n): the latter samples with
            # replacement, which would duplicate some slices and drop others.
            random_idxs = np.random.permutation(len(dwi_slices))
            dwi_slices = dwi_slices[random_idxs]

        np.save(filename + "_data.npy", dwi_slices)
        del dwi_slices  # free memory

        mask_slices = []
        print("\n\nProcessing Segs...")
        for s in subjects:
            print("processing seg subject {}".format(s))
            mask_data = ImgUtils.create_multilabel_mask(
                HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                mask_data = ImgUtils.resize_first_three_dims(mask_data,
                                                             order=0,
                                                             zoom=0.5)
            mask_data = DatasetUtils.scale_input_to_unet_shape(
                mask_data, HP.DATASET, HP.RESOLUTION)
            # Slice by the mask's own shape rather than the shape of whichever dwi
            # volume happened to be loaded last.
            mask_slices.extend(_slices_all_axes(mask_data))

        mask_slices = np.array(mask_slices)
        print("SEG TYPE: {}".format(mask_slices.dtype))
        if shuffle:
            # Same permutation as the data so slices stay aligned.
            mask_slices = mask_slices[random_idxs]

        np.save(filename + "_seg.npy", mask_slices)