Example #1
    def create_one_3D_file():
        '''
        Create one big file which contains all 3D Images (not slices).
        '''

        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            DATASET_FOLDER = "HCP"

        data_all = []
        seg_all = []

        print("\n\nProcessing Data...")
        for s in get_all_subjects():
            print("processing data subject {}".format(s))
            data = nib.load(join(C.HOME, HP.DATASET_FOLDER, s, HP.FEATURES_FILENAME + ".nii.gz")).get_data()
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
            data_all.append(np.array(data))  # append inside the loop, one volume per subject
        np.save("data.npy", data_all)
        del data_all  # free memory

        print("\n\nProcessing Segs...")
        for s in get_all_subjects():
            print("processing seg subject {}".format(s))
            seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            seg_all.append(np.array(seg))  # append inside the loop
        seg_all = np.array(seg_all)  # a plain list has no dtype; stack into an ndarray first
        print("SEG TYPE: {}".format(seg_all.dtype))
        np.save("seg.npy", seg_all)
Example #2
    def save_fusion_nifti_as_npy():

        #Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        #change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing data subject {}".format(s))
            start_time = time.time()
            data = nib.load(
                join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                     s + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, HP.DATASET, HP.RESOLUTION)
            data = data[:-1, :, :-1, :]  # cut one voxel at the end: scale_input_to_world_shape output 146 -> one too many
            ExpUtils.make_dir(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
            np.save(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
            print("Took {}s".format(time.time() - start_time))

            print("processing seg subject {}".format(s))
            start_time = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            seg = nib.load(
                join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                     HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, HP.DATASET, HP.RESOLUTION)
            np.save(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - start_time))
Example #3
    def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):

        mask_dir = join(C.HOME, HP.DATASET_FOLDER)

        input_dir = HP.MULTI_PARENT_PATH

        combined_slices = []
        mask_slices = []

        for s in subjects:
            print("processing subject {}".format(s))

            probs_x = nib.load(join(input_dir, "UNet_x_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data()
            probs_y = nib.load(join(input_dir, "UNet_y_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data()
            probs_z = nib.load(join(input_dir, "UNet_z_" + str(HP.CV_FOLD), "probmaps", s + "_probmap.nii.gz")).get_data()
            # probs_x = DatasetUtils.scale_input_to_unet_shape(probs_x, HP.DATASET, HP.RESOLUTION)
            # probs_y = DatasetUtils.scale_input_to_unet_shape(probs_y, HP.DATASET, HP.RESOLUTION)
            # probs_z = DatasetUtils.scale_input_to_unet_shape(probs_z, HP.DATASET, HP.RESOLUTION)
            combined = np.stack((probs_x, probs_y, probs_z), axis=4)  # (73, 87, 73, 18, 3)  # one dim too many for the UNet -> reshape below
            combined = np.reshape(combined, (combined.shape[0], combined.shape[1], combined.shape[2],
                                             combined.shape[3] * combined.shape[4]))    # (73, 87, 73, 3*18)

            # print("combined shape after", combined.shape)

            mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            if HP.DATASET == "HCP_2mm":
                #use "HCP" because for mask we need downscaling
                mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, "HCP", HP.RESOLUTION)
            elif HP.DATASET == "HCP_2.5mm":
                # use "HCP" because for mask we need downscaling
                mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, "HCP", HP.RESOLUTION)
            else:
                # Mask has same resolution as probmaps -> we can use same resizing
                mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

            # Save as Img
            img = nib.Nifti1Image(combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(img, join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))


            combined = DatasetUtils.scale_input_to_unet_shape(combined, HP.DATASET, HP.RESOLUTION)
            assert (combined.shape[2] == mask_data.shape[2])

            #Save as Slices
            for z in range(combined.shape[2]):
                combined_slices.append(combined[:, :, z, :])
                mask_slices.append(mask_data[:, :, z, :])

        if shuffle:
            combined_slices, mask_slices = sk_shuffle(combined_slices, mask_slices, random_state=9)

        if HP.TRAIN:
            np.save(filename + "_data.npy", combined_slices)
            np.save(filename + "_seg.npy", mask_slices)
Example #4
    def save_fusion_nifti_as_npy():

        #Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        #change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing data subject {}".format(s))
            start_time = time.time()
            data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER, s + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
            data = data[:-1, :, :-1, :]  # cut one voxel at the end: scale_input_to_world_shape output 146 -> one too many
            ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
            np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s, DIFFUSION_FOLDER + "_xyz.npy"), data)
            print("Took {}s".format(time.time() - start_time))

            print("processing seg subject {}".format(s))
            start_time = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s, HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s, "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - start_time))
Example #5
    def create_one_3D_file():
        '''
        Create one big file which contains all 3D Images (not slices).
        '''
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            DATASET_FOLDER = "HCP"

        data_all = []
        seg_all = []

        print("\n\nProcessing Data...")
        for s in get_all_subjects():
            print("processing data subject {}".format(s))
            data = nib.load(
                join(C.HOME, HP.DATASET_FOLDER, s,
                     HP.FEATURES_FILENAME + ".nii.gz")).get_data()
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, HP.DATASET, HP.RESOLUTION)
            data_all.append(np.array(data))  # append inside the loop, one volume per subject
        np.save("data.npy", data_all)
        del data_all  # free memory

        print("\n\nProcessing Segs...")
        for s in get_all_subjects():
            print("processing seg subject {}".format(s))
            seg = ImgUtils.create_multilabel_mask(HP,
                                                  s,
                                                  labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, HP.DATASET, HP.RESOLUTION)
            seg_all.append(np.array(seg))  # append inside the loop
        seg_all = np.array(seg_all)  # a plain list has no dtype; stack into an ndarray first
        print("SEG TYPE: {}".format(seg_all.dtype))
        np.save("seg.npy", seg_all)
Example #6
        def finalize_data(layers):
            layers = np.array(layers)

            # Put the stacked slices back into (x, y, z, channels) order
            if HP.SLICE_DIRECTION == "x":
                layers = layers.transpose(0, 1, 2, 3)  # already (x, y, z, c); identity kept for symmetry

            elif HP.SLICE_DIRECTION == "y":
                layers = layers.transpose(1, 0, 2, 3)

            elif HP.SLICE_DIRECTION == "z":
                layers = layers.transpose(1, 2, 0, 3)

            if scale_to_world_shape:
                layers = DatasetUtils.scale_input_to_world_shape(layers, HP.DATASET, HP.RESOLUTION)

            return layers.astype(np.float32)
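
The transposes above undo the stacking order: slicing a volume along one axis and re-stacking puts that axis first, so e.g. z-slices stack to (z, x, y, c) and transpose(1, 2, 0, 3) restores (x, y, z, c). A minimal self-contained check (shapes are illustrative):

    import numpy as np

    vol = np.random.rand(4, 5, 6, 2)                        # (x, y, z, c)
    layers = np.array([vol[:, :, z, :] for z in range(6)])  # stacked as (z, x, y, c)
    restored = layers.transpose(1, 2, 0, 3)                 # back to (x, y, z, c)
    assert np.array_equal(restored, vol)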
Example #7
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
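
A hypothetical usage sketch for run_tractseg (the file names and output handling are assumptions; only the function signature comes from the code above):

    import nibabel as nib
    import numpy as np

    peaks_img = nib.load("peaks.nii.gz")   # 4D peaks image, shape (x, y, z, 9)
    peaks = peaks_img.get_data()

    seg = run_tractseg(peaks, output_type="tract_segmentation")  # (x, y, z, nr_of_bundles)
    nib.save(nib.Nifti1Image(seg.astype(np.uint8), peaks_img.affine),
             "bundle_segmentations.nii.gz")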
Example #8
    def get_batches(self, batch_size=1):

        num_processes = 1  # do not use more than 1 if you want to keep the original slice order (threads return in random order)

        if self.HP.TYPE == "combined":
            # Load from Npy file for Fusion
            data = self.subject
            seg = []
            nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
            num_batches = int(nr_of_samples / batch_size / num_processes)
            batch_gen = SlicesBatchGeneratorNpyImg_fusion(
                (data, seg),
                BATCH_SIZE=batch_size,
                num_batches=num_batches,
                seed=None)
        else:
            # Load Features
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                data_img = nib.load(
                    join(self.data_dir, "270g_125mm_peaks.nii.gz"))
            else:
                data_img = nib.load(
                    join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
            data = data_img.get_data()
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, self.HP.DATASET, self.HP.RESOLUTION)
            # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  #If we want to test HCP_32g on HighRes net

            #Load Segmentation
            if self.use_gt_mask:
                seg = nib.load(
                    join(self.data_dir,
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

                if self.HP.LABELS_FILENAME not in [
                        "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                        "bundle_peaks_808080", "bundle_masks_20_808080",
                        "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                        "bundle_peaks_Part2_808080",
                        "bundle_peaks_Part3_808080",
                        "bundle_peaks_Part4_808080"
                ]:
                    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, "HCP", self.HP.RESOLUTION)
                    else:
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, self.HP.DATASET, self.HP.RESOLUTION)
            else:
                # Use a dummy mask in case we only want to predict on data for which we have no ground truth
                seg = np.zeros(
                    (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                     self.HP.INPUT_DIM[0],
                     self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

            batch_gen = SlicesBatchGenerator((data, seg),
                                             batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

        if self.HP.TEST_TIME_DAUG:
            center_dist_from_border = int(
                self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    self.HP.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=True,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=True,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True))
            # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
            # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
            tfs.append(
                ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                              preserve_range=True,
                                              per_channel=False))
            tfs.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                  per_channel=False))

        tfs.append(ReorderSegTransform())
        batch_gen = MultiThreadedAugmenter(
            batch_gen,
            Compose(tfs),
            num_processes=num_processes,
            num_cached_per_queue=2,
            seeds=None
        )  # Only use num_processes=1, otherwise the global_idx of SlicesBatchGenerator does not work
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
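
A hypothetical consumption sketch: the returned MultiThreadedAugmenter yields dicts, so callers (as get_seg_single_img in a later example does) iterate it and read the "data" and "seg" keys (data_manager here is an assumed instance of this class):

    batch_gen = data_manager.get_batches(batch_size=1)
    for batch in batch_gen:
        x = batch["data"]  # (batch_size, channels, x, y)
        y = batch["seg"]   # (batch_size, x, y, channels) after ReorderSegTransform
        break              # one batch is enough for this sketch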
Example #9
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
Example #10
    def get_batches(self, batch_size=1):

        num_processes = 1   # do not use more than 1 if you want to keep the original slice order (threads return in random order)

        if self.HP.TYPE == "combined":
            # Load from Npy file for Fusion
            data = self.subject
            seg = []
            nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
            num_batches = int(nr_of_samples / batch_size / num_processes)
            batch_gen = SlicesBatchGeneratorNpyImg_fusion((data, seg), BATCH_SIZE=batch_size, num_batches=num_batches, seed=None)
        else:
            # Load Features
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                data_img = nib.load(join(self.data_dir, "270g_125mm_peaks.nii.gz"))
            else:
                data_img = nib.load(join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
            data = data_img.get_data()
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)
            # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")  #If we want to test HCP_32g on HighRes net

            #Load Segmentation
            if self.use_gt_mask:
                seg = nib.load(join(self.data_dir, self.HP.LABELS_FILENAME + ".nii.gz")).get_data()

                if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                                   "bundle_masks_20_808080", "bundle_masks_72_808080"]:
                    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask
                        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
                    else:
                        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)
            else:
                # Use a dummy mask in case we only want to predict on data for which we have no ground truth
                seg = np.zeros((self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0], self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

            batch_gen = SlicesBatchGenerator((data, seg), BATCH_SIZE=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

        if self.HP.TEST_TIME_DAUG:
            center_dist_from_border = int(self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(SpatialTransform(self.HP.INPUT_DIM,
                                        patch_center_dist_from_border=center_dist_from_border,
                                        do_elastic_deform=True, alpha=(90., 120.), sigma=(9., 11.),
                                        do_rotation=True, angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8),
                                        angle_z=(-0.8, 0.8),
                                        do_scale=True, scale=(0.9, 1.5), border_mode_data='constant',
                                        border_cval_data=0,
                                        order_data=3,
                                        border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True))
            # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
            # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
            tfs.append(ContrastAugmentationTransform(contrast_range=(0.7, 1.3), preserve_range=True, per_channel=False))
            tfs.append(BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3), per_channel=False))


        tfs.append(ReorderSegTransform())
        batch_gen = MultiThreadedAugmenter(batch_gen, Compose(tfs), num_processes=num_processes, num_cached_per_queue=2, seeds=None)  # Only use num_processes=1, otherwise the global_idx of SlicesBatchGenerator does not work
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
Example #11
    def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
        data_dir = join(C.HOME, HP.DATASET_FOLDER)

        dwi_slices = []
        mask_slices = []

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing dwi subject {}".format(s))

            dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
            dwi_data = dwi.get_data()
            dwi_data = np.nan_to_num(dwi_data)
            dwi_data = DatasetUtils.scale_input_to_unet_shape(dwi_data, HP.DATASET, HP.RESOLUTION)

            # if slice == "x":
            #     for z in range(dwi_data.shape[0]):
            #         dwi_slices.append(dwi_data[z, :, :, :])
            #
            # if slice == "y":
            #     for z in range(dwi_data.shape[1]):
            #         dwi_slices.append(dwi_data[:, z, :, :])
            #
            # if slice == "z":
            #     for z in range(dwi_data.shape[2]):
            #         dwi_slices.append(dwi_data[:, :, z, :])

            #Use slices from all directions in one dataset
            for z in range(dwi_data.shape[0]):
                dwi_slices.append(dwi_data[z, :, :, :])
            for z in range(dwi_data.shape[1]):
                dwi_slices.append(dwi_data[:, z, :, :])
            for z in range(dwi_data.shape[2]):
                dwi_slices.append(dwi_data[:, :, z, :])

        dwi_slices = np.array(dwi_slices)
        random_idxs = None
        if shuffle:
            random_idxs = np.random.permutation(len(dwi_slices))  # choice() samples with replacement; permutation gives a true shuffle
            dwi_slices = dwi_slices[random_idxs]

        np.save(filename + "_data.npy", dwi_slices)
        del dwi_slices  #free memory


        print("\n\nProcessing Segs...")
        for s in subjects:
            print("processing seg subject {}".format(s))

            mask_data = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                mask_data = ImgUtils.resize_first_three_dims(mask_data, order=0, zoom=0.5)
            mask_data = DatasetUtils.scale_input_to_unet_shape(mask_data, HP.DATASET, HP.RESOLUTION)

            # if slice == "x":
            #     for z in range(dwi_data.shape[0]):
            #         mask_slices.append(mask_data[z, :, :, :])
            #
            # if slice == "y":
            #     for z in range(dwi_data.shape[1]):
            #         mask_slices.append(mask_data[:, z, :, :])
            #
            # if slice == "z":
            #     for z in range(dwi_data.shape[2]):
            #         mask_slices.append(mask_data[:, :, z, :])

            # Use slices from all directions in one dataset
            for z in range(mask_data.shape[0]):  # slice mask_data, not the dwi_data leaked from the previous loop
                mask_slices.append(mask_data[z, :, :, :])
            for z in range(mask_data.shape[1]):
                mask_slices.append(mask_data[:, z, :, :])
            for z in range(mask_data.shape[2]):
                mask_slices.append(mask_data[:, :, z, :])

        mask_slices = np.array(mask_slices)
        print("SEG TYPE: {}".format(mask_slices.dtype))
        if shuffle:
            mask_slices = mask_slices[random_idxs]

        np.save(filename + "_seg.npy", mask_slices)
Example #12
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))     # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # replace NaNs; training does not work otherwise
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

        # Randomly sample slice orientation (uniformly; round(uniform(0, 2)) would oversample the middle direction)
        slice_direction = np.random.randint(0, 3)

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(2, 3, 0, 1)


        sw = 5 #slice_window (only odd numbers allowed)
        pad = int((sw-1) / 2)

        data_pad = np.zeros((data.shape[0]+sw-1, data.shape[1]+sw-1, data.shape[2]+sw-1, data.shape[3])).astype(data.dtype)
        data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data   #padded with two slices of zeros on all sides
        batch=[]
        for s_idx in slice_idxs:
            if slice_direction == 0:
                #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
                x = data_pad[s_idx:s_idx+sw, pad:-pad, pad:-pad, :].astype(np.float32)      # (5, y, z, channels)
                x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx+sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx+sw, :].astype(np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
        data_dict = {"data": np.array(batch),     # (batch_size, channels, x, y, [z])
                     "seg": y}                    # (batch_size, channels, x, y, [z])

        return data_dict
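
A minimal sketch of the sw=5 slice window used above: pad the volume by two voxels per side, take five neighbouring slices, and fold them into the channel dimension (dummy shapes, slice_direction == 0):

    import numpy as np

    sw, pad = 5, 2
    data = np.random.rand(16, 16, 16, 9).astype(np.float32)
    data_pad = np.zeros((16 + sw - 1,) * 3 + (9,), dtype=data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data

    s_idx = 7
    x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]  # (5, y, z, channels)
    x = x.transpose(0, 3, 1, 2)                            # (5, channels, y, z)
    x = x.reshape(-1, x.shape[2], x.shape[3])              # (5*channels, y, z)
    print(x.shape)  # (45, 16, 16)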
Example #13
    def _create_slices_file(HP, subjects, filename, slice, shuffle=True):
        data_dir = join(C.HOME, HP.DATASET_FOLDER)

        dwi_slices = []
        mask_slices = []

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing dwi subject {}".format(s))

            dwi = nib.load(join(data_dir, s, HP.FEATURES_FILENAME + ".nii.gz"))
            dwi_data = dwi.get_data()
            dwi_data = np.nan_to_num(dwi_data)
            dwi_data = DatasetUtils.scale_input_to_unet_shape(
                dwi_data, HP.DATASET, HP.RESOLUTION)

            # if slice == "x":
            #     for z in range(dwi_data.shape[0]):
            #         dwi_slices.append(dwi_data[z, :, :, :])
            #
            # if slice == "y":
            #     for z in range(dwi_data.shape[1]):
            #         dwi_slices.append(dwi_data[:, z, :, :])
            #
            # if slice == "z":
            #     for z in range(dwi_data.shape[2]):
            #         dwi_slices.append(dwi_data[:, :, z, :])

            #Use slices from all directions in one dataset
            for z in range(dwi_data.shape[0]):
                dwi_slices.append(dwi_data[z, :, :, :])
            for z in range(dwi_data.shape[1]):
                dwi_slices.append(dwi_data[:, z, :, :])
            for z in range(dwi_data.shape[2]):
                dwi_slices.append(dwi_data[:, :, z, :])

        dwi_slices = np.array(dwi_slices)
        random_idxs = None
        if shuffle:
            random_idxs = np.random.permutation(len(dwi_slices))  # choice() samples with replacement; permutation gives a true shuffle
            dwi_slices = dwi_slices[random_idxs]

        np.save(filename + "_data.npy", dwi_slices)
        del dwi_slices  #free memory

        print("\n\nProcessing Segs...")
        for s in subjects:
            print("processing seg subject {}".format(s))

            mask_data = ImgUtils.create_multilabel_mask(
                HP, s, labels_type=HP.LABELS_TYPE)
            if HP.RESOLUTION == "2.5mm":
                mask_data = ImgUtils.resize_first_three_dims(mask_data,
                                                             order=0,
                                                             zoom=0.5)
            mask_data = DatasetUtils.scale_input_to_unet_shape(
                mask_data, HP.DATASET, HP.RESOLUTION)

            # if slice == "x":
            #     for z in range(dwi_data.shape[0]):
            #         mask_slices.append(mask_data[z, :, :, :])
            #
            # if slice == "y":
            #     for z in range(dwi_data.shape[1]):
            #         mask_slices.append(mask_data[:, z, :, :])
            #
            # if slice == "z":
            #     for z in range(dwi_data.shape[2]):
            #         mask_slices.append(mask_data[:, :, z, :])

            # Use slices from all directions in one dataset
            for z in range(mask_data.shape[0]):  # slice mask_data, not the dwi_data leaked from the previous loop
                mask_slices.append(mask_data[z, :, :, :])
            for z in range(mask_data.shape[1]):
                mask_slices.append(mask_data[:, z, :, :])
            for z in range(mask_data.shape[2]):
                mask_slices.append(mask_data[:, :, z, :])

        mask_slices = np.array(mask_slices)
        print("SEG TYPE: {}".format(mask_slices.dtype))
        if shuffle:
            mask_slices = mask_slices[random_idxs]

        np.save(filename + "_seg.npy", mask_slices)
Example #14
    def _create_prob_slices_file(HP, subjects, filename, bundle, shuffle=True):

        mask_dir = join(C.HOME, HP.DATASET_FOLDER)

        input_dir = HP.MULTI_PARENT_PATH

        combined_slices = []
        mask_slices = []

        for s in subjects:
            print("processing subject {}".format(s))

            probs_x = nib.load(
                join(input_dir, "UNet_x_" + str(HP.CV_FOLD), "probmaps",
                     s + "_probmap.nii.gz")).get_data()
            probs_y = nib.load(
                join(input_dir, "UNet_y_" + str(HP.CV_FOLD), "probmaps",
                     s + "_probmap.nii.gz")).get_data()
            probs_z = nib.load(
                join(input_dir, "UNet_z_" + str(HP.CV_FOLD), "probmaps",
                     s + "_probmap.nii.gz")).get_data()
            # probs_x = DatasetUtils.scale_input_to_unet_shape(probs_x, HP.DATASET, HP.RESOLUTION)
            # probs_y = DatasetUtils.scale_input_to_unet_shape(probs_y, HP.DATASET, HP.RESOLUTION)
            # probs_z = DatasetUtils.scale_input_to_unet_shape(probs_z, HP.DATASET, HP.RESOLUTION)
            combined = np.stack(
                (probs_x, probs_y, probs_z), axis=4
            )  # (73, 87, 73, 18, 3)  # one dim too many for the UNet -> reshape below
            combined = np.reshape(
                combined,
                (combined.shape[0], combined.shape[1], combined.shape[2],
                 combined.shape[3] * combined.shape[4]))  # (73, 87, 73, 3*18)

            # print("combined shape after", combined.shape)

            mask_data = ImgUtils.create_multilabel_mask(
                HP, s, labels_type=HP.LABELS_TYPE)
            if HP.DATASET == "HCP_2mm":
                #use "HCP" because for mask we need downscaling
                mask_data = DatasetUtils.scale_input_to_unet_shape(
                    mask_data, "HCP", HP.RESOLUTION)
            elif HP.DATASET == "HCP_2.5mm":
                # use "HCP" because for mask we need downscaling
                mask_data = DatasetUtils.scale_input_to_unet_shape(
                    mask_data, "HCP", HP.RESOLUTION)
            else:
                # Mask has same resolution as probmaps -> we can use same resizing
                mask_data = DatasetUtils.scale_input_to_unet_shape(
                    mask_data, HP.DATASET, HP.RESOLUTION)

            # Save as Img
            img = nib.Nifti1Image(
                combined, ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(
                img,
                join(HP.EXP_PATH, "combined", s + "_combinded_probmap.nii.gz"))

            combined = DatasetUtils.scale_input_to_unet_shape(
                combined, HP.DATASET, HP.RESOLUTION)
            assert (combined.shape[2] == mask_data.shape[2])

            #Save as Slices
            for z in range(combined.shape[2]):
                combined_slices.append(combined[:, :, z, :])
                mask_slices.append(mask_data[:, :, z, :])

        if shuffle:
            combined_slices, mask_slices = sk_shuffle(combined_slices,
                                                      mask_slices,
                                                      random_state=9)

        if HP.TRAIN:
            np.save(filename + "_data.npy", combined_slices)
            np.save(filename + "_seg.npy", mask_slices)
Example #15
    def get_seg_single_img(self, HP, probs=False, scale_to_world_shape=True):
        '''
        Returns the prediction for one image (the batch manager must only return batches for a single image).

        :param HP: hyperparameter object
        :return: ([146, 174, 146, nrClasses], [146, 174, 146, nrClasses])    (prediction, ground truth)
        '''

        # Test-time DAug (vestigial loop: runs exactly once)
        for i in range(1):
            # segs = []
            # ys = []

            layers_seg = []
            layers_y = []
            batch_generator = self.dataManager.get_batches(batch_size=1)
            batch_generator = list(batch_generator)
            for j in tqdm(list(range(len(batch_generator)))):
                batch = batch_generator[j]
                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, x, y, nr_of_classes)
                y = y.astype(HP.LABELS_TYPE)
                y = np.squeeze(y)  # remove bs dimension which is only 1 -> (x, y, nrClasses)

                #For normal prediction
                layer_probs = self.model.get_probs(x)  # (bs, x, y, nrClasses)
                layer_probs = np.squeeze(layer_probs)  # remove bs dimension which is only 1 -> (x, y, nrClasses)

                #For Dropout Sampling (must set Deterministic=False in model)
                # NR_SAMPLING = 30
                # samples = []
                # for i in range(NR_SAMPLING):
                #     layer_probs = self.model.get_probs(x)  # (bs, x, y, nrClasses)
                #     samples.append(layer_probs)
                #
                # samples = np.array(samples)  # (NR_SAMPLING, bs, x, y, nrClasses)
                # samples = np.squeeze(samples) # (NR_SAMPLING, x, y, nrClasses)
                # layer_probs = np.mean(samples, axis=0)
                # #layer_probs = np.std(samples, axis=0)    #use std

                if probs:
                    seg = layer_probs  # (x, y, nrClasses)
                else:
                    seg = layer_probs
                    seg[seg >= HP.THRESHOLD] = 1
                    seg[seg < HP.THRESHOLD] = 0
                    seg = seg.astype(np.int16)

                layers_seg.append(seg)
                layers_y.append(y)
            layers_seg = np.array(layers_seg)
            layers_y = np.array(layers_y)

        # Put axes back into (x, y, z) order
        if HP.SLICE_DIRECTION == "x":
            layers_seg = layers_seg.transpose(0, 1, 2, 3)
            layers_y = layers_y.transpose(0, 1, 2, 3)

        elif HP.SLICE_DIRECTION == "y":
            layers_seg = layers_seg.transpose(1, 0, 2, 3)
            layers_y = layers_y.transpose(1, 0, 2, 3)

        elif HP.SLICE_DIRECTION == "z":
            layers_seg = layers_seg.transpose(1, 2, 0, 3)
            layers_y = layers_y.transpose(1, 2, 0, 3)

        if scale_to_world_shape:
            layers_seg = DatasetUtils.scale_input_to_world_shape(
                layers_seg, HP.DATASET, HP.RESOLUTION)
            layers_y = DatasetUtils.scale_input_to_world_shape(
                layers_y, HP.DATASET, HP.RESOLUTION)

        layers_seg = layers_seg.astype(np.float32)
        layers_y = layers_y.astype(np.float32)

        return layers_seg, layers_y  # (Prediction, Groundtruth)
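
A minimal sketch of the commented-out Monte Carlo dropout sampling above: run several stochastic forward passes and reduce them to a mean prediction plus a per-voxel uncertainty map (get_probs_stub stands in for self.model.get_probs with dropout kept active):

    import numpy as np

    def get_probs_stub(x):
        # stand-in for a non-deterministic forward pass with dropout enabled
        return np.random.rand(1, 8, 8, 3)

    NR_SAMPLING = 30
    samples = np.squeeze(np.array([get_probs_stub(None) for _ in range(NR_SAMPLING)]))
    mean_probs = np.mean(samples, axis=0)    # averaged prediction
    uncertainty = np.std(samples, axis=0)    # higher std = less certain voxel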
Example #16
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(
            random.uniform(0, len(subjects))
        )  # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], self.HP.FEATURES_FILENAME +
                             ".nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # replace NaNs; they break the scaling and training steps
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in [
                "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                "bundle_peaks_808080", "bundle_masks_20_808080",
                "bundle_masks_72_808080"
        ]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # Passing "HCP" with a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, self.HP.DATASET,
                    self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, replace=False)

        # Randomly sample the slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            slice_direction = random.randint(0, 2)  # uniform over x/y/z; int(round(uniform(0, 2))) would oversample 1
        else:
            slice_direction = 1  # always use y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = x.transpose(0, 3, 1, 2)  # channel dim has to come before width and height for the U-Net (but after the batch dim)
            y = y.transpose(0, 3, 1, 2)  # class dim has to come before width and height for data augmentation -> (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = x.transpose(1, 3, 0, 2)
            y = y.transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = x.transpose(2, 3, 0, 1)
            y = y.transpose(2, 3, 0, 1)

        data_dict = {
            "data": x,  # (batch_size, channels, x, y)
            "seg": y,   # (batch_size, nr_of_classes, x, y)
        }
        return data_dict
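
A minimal sketch of the slicing and transposing above, on random arrays (all shapes hypothetical): sampling BATCH_SIZE slices along one axis and moving the channel/class dimension to position 1 yields the (batch_size, channels, H, W) layout the U-Net expects.

import numpy as np

data = np.random.rand(80, 80, 80, 9).astype(np.float32)          # (x, y, z, channels)
seg = np.random.randint(0, 2, (80, 80, 80, 72), dtype=np.uint8)  # (x, y, z, classes)
batch_size = 4

slice_idxs = np.random.choice(data.shape[0], batch_size, replace=False)
x = data[slice_idxs].transpose(0, 3, 1, 2)  # slice_direction == 0 -> (bs, channels, y, z)
y = seg[slice_idxs].transpose(0, 3, 1, 2)   # -> (bs, classes, y, z)
print(x.shape, y.shape)  # (4, 9, 80, 80) (4, 72, 80, 80)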
Example #17
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = random.randint(0, len(subjects) - 1)  # pick one random subject

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # replace NaNs; they break the scaling and training steps
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # Passing "HCP" with a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, replace=False)

        # Randomly sample the slice orientation
        slice_direction = random.randint(0, 2)  # uniform over x/y/z; int(round(uniform(0, 2))) would oversample 1

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = y.transpose(0, 3, 1, 2)  # class dim has to come before width and height for data augmentation -> (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = y.transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = y.transpose(2, 3, 0, 1)

        sw = 5  # slice window (must be odd)
        pad = (sw - 1) // 2

        # zero-pad the volume with `pad` slices on every side
        data_pad = np.zeros((data.shape[0] + sw - 1,
                             data.shape[1] + sw - 1,
                             data.shape[2] + sw - 1,
                             data.shape[3]), dtype=data.dtype)
        data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data
        batch = []
        for s_idx in slice_idxs:
            # indices into data_pad are shifted by `pad`, so the window
            # s_idx:s_idx+sw is centered on slice s_idx of the original volume
            if slice_direction == 0:
                x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :].astype(np.float32)  # (sw, y, z, channels)
                x = x.transpose(0, 3, 1, 2)  # channel dim has to come before width and height for the U-Net -> (sw, channels, y, z)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(np.float32)  # (x, sw, z, channels)
                x = x.transpose(1, 3, 0, 2)  # -> (sw, channels, x, z)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(np.float32)  # (x, y, sw, channels)
                x = x.transpose(2, 3, 0, 1)  # -> (sw, channels, x, y)
            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])  # fold the window into the channels -> (sw*channels, H, W)
            batch.append(x)
        data_dict = {
            "data": np.array(batch),  # (batch_size, sw*channels, x, y)
            "seg": y                  # (batch_size, nr_of_classes, x, y)
        }

        return data_dict
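
A minimal sketch of the slice-window ("2.5D") input built above, on a synthetic volume (shapes illustrative): sw neighbouring slices are cut out of the zero-padded volume around each sampled slice and folded into the channel dimension, giving the network limited 3D context at 2D cost.

import numpy as np

data = np.random.rand(80, 80, 80, 9).astype(np.float32)  # (x, y, z, channels)
sw = 5                                                   # slice window; must be odd
pad = (sw - 1) // 2

data_pad = np.zeros((data.shape[0] + sw - 1,
                     data.shape[1] + sw - 1,
                     data.shape[2] + sw - 1,
                     data.shape[3]), dtype=data.dtype)
data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data         # zero border of `pad` slices

s_idx = 40                                               # one sampled slice, direction x
x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]    # (sw, y, z, channels)
x = x.transpose(0, 3, 1, 2)                              # (sw, channels, y, z)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])  # (sw*channels, y, z)
print(x.shape)  # (45, 80, 80)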
Example #18
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = random.randint(0, len(subjects) - 1)     # pick one random subject

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # replace NaNs; they break the scaling and training steps
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                           "bundle_masks_20_808080", "bundle_masks_72_808080"]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # Passing "HCP" with a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
                seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, replace=False)

        # Randomly sample the slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            slice_direction = random.randint(0, 2)  # uniform over x/y/z; int(round(uniform(0, 2))) would oversample 1
        else:
            slice_direction = 1  # always use y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(np.float32)      # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(0, 3, 1, 2)  # channel dim has to come before width and height for the U-Net (but after the batch dim)
            y = np.array(y).transpose(0, 3, 1, 2)  # class dim has to come before width and height for data augmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(np.float32)      # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(1, 3, 0, 2)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(np.float32)      # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(2, 3, 0, 1)
            y = np.array(y).transpose(2, 3, 0, 1)

        data_dict = {"data": x,     # (batch_size, channels, x, y, [z])
                     "seg": y}      # (batch_size, channels, x, y, [z])
        return data_dict
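
The load-with-retries pattern that all three generators share could be factored out. A minimal sketch, assuming nothing beyond the standard library (load_fn, n_tries and wait_s are illustrative names, not part of the repository):

import time

def load_with_retries(load_fn, n_tries=20, wait_s=20):
    """Call load_fn() until it succeeds, sleeping between IOError retries."""
    for i in range(n_tries):
        try:
            return load_fn()
        except IOError:
            print("WARNING: could not load file (try {}), sleeping {}s".format(i, wait_s))
            time.sleep(wait_s)
    raise IOError("giving up after {} tries".format(n_tries))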