Example #1
0
def _crop_nifti_file_to_bbox(src_path, dst_path, bbox, replace_nan):
    """Crop one NIfTI file to ``bbox`` and save it with the source affine.

    Args:
        src_path: Path of the NIfTI file to read.
        dst_path: Path the cropped NIfTI file is written to.
        bbox: Bounding box as returned by ``data_utils.crop_to_nonzero``.
        replace_nan: If True, NaN/inf values are replaced via ``np.nan_to_num``
            before cropping.
    """
    img = nib.load(src_path)
    # Equivalent of ``img.get_data()``, which nibabel deprecated in 3.0 and removed in 5.0
    data = np.asanyarray(img.dataobj)
    if replace_nan:
        data = np.nan_to_num(data)

    # crop_to_nonzero iterates over a trailing channel axis, so 3D volumes need one
    if data.ndim == 3:
        data = data[..., None]

    data, _, _, _ = data_utils.crop_to_nonzero(data, bbox=bbox)
    nib.save(nib.Nifti1Image(data, img.affine), dst_path)


def create_preprocessed_files(subject):
    """Crop all configured NIfTI files of one subject to a common bounding box.

    The bounding box is the nonzero region of ``bb_file`` (after NaN
    replacement). It is applied to every file in ``filenames_data`` (with NaN
    replacement) and ``filenames_seg`` (without). Cropped volumes are written as
    ``<name>.nii.gz`` to ``C.DATA_PATH/DATASET_FOLDER_PREPROC/<subject>`` with
    the affine of the respective source file.

    Args:
        subject: Subject identifier. It must be contained in the module-level
            ``subjects`` list, whose index for it is printed as progress info.

    Raises:
        IOError: If one of the ``filenames_data`` files does not exist.
    """
    # Estimate bounding box from this file and then apply it to all other files
    bb_file = "12g_125mm_peaks"

    # TODO: adapt
    # filenames_data = ["12g_125mm_peaks", "90g_125mm_peaks", "270g_125mm_peaks",
    #                   "12g_125mm_bedpostx_peaks_scaled", "90g_125mm_bedpostx_peaks_scaled",
    #                   "270g_125mm_bedpostx_peaks_scaled"]
    # filenames_seg = ["bundle_masks_72", "bundle_masks_dm", "endpoints_72_ordered",
    #                  "bundle_peaks_Part1", "bundle_peaks_Part2", "bundle_peaks_Part3", "bundle_peaks_Part4",
    #                  "bundle_masks_autoPTX_dm", "bundle_masks_autoPTX_thr001"]

    filenames_data = ["bundle_uncertainties"]
    filenames_seg = []

    print("idx: {}".format(subjects.index(subject)))
    src_dir = join(C.NETWORK_DRIVE, DATASET_FOLDER, subject)
    dst_dir = join(C.DATA_PATH, DATASET_FOLDER_PREPROC, subject)
    exp_utils.make_dir(dst_dir)

    # Get bounding box
    bb_img = nib.load(join(src_dir, bb_file + ".nii.gz"))
    bb_data = np.nan_to_num(np.asanyarray(bb_img.dataobj))
    _, _, bbox, _ = data_utils.crop_to_nonzero(bb_data)

    for filename in filenames_data:
        path = join(src_dir, filename + ".nii.gz")
        if not os.path.exists(path):
            # Missing input is fatal: report which file, then abort
            print("missing file: {}-{}".format(subject, filename))
            raise IOError("File missing: {}".format(path))
        _crop_nifti_file_to_bbox(path, join(dst_dir, filename + ".nii.gz"),
                                 bbox, replace_nan=True)

    for filename in filenames_seg:
        _crop_nifti_file_to_bbox(join(src_dir, filename + ".nii.gz"),
                                 join(dst_dir, filename + ".nii.gz"),
                                 bbox, replace_nan=False)
def preprocess_nifti_file(tom_fn, mask_fn, seeds_fn, ends_fn, beginning_fn, ending_fn):
    """Crop and pad a set of NIfTI volumes using the bounding box of a mask.

    Every argument is an ``(input_path, output_path)`` pair. The bounding box is
    the nonzero region of the (NaN-cleaned) input of ``mask_fn``. If the mask is
    empty, the full volume is used. Each input is cropped to that box, then
    passed through ``data_utils.pad_and_scale_img_to_square_img`` and saved to
    its output path.

    NOTE(review): the mask's affine is written unchanged, although the data has
    been cropped and padded/scaled -- confirm consumers do not rely on it.
    """
    # Get the appropriate affine transformation from the mask
    mask_object = nib.load(mask_fn[0])
    # np.asanyarray(dataobj) replaces get_data(), which nibabel removed in 5.0
    mask_data = np.nan_to_num(np.asanyarray(mask_object.dataobj))
    affine = mask_object.affine

    # Compute the bounding box using the mask. np.any (not np.sum) so that NaN
    # masks and values that cancel out are not misclassified.
    if np.any(mask_data):
        bbox = data_utils.get_bbox_from_mask(mask_data, 0)
    else:
        bbox = [[0, mask_data.shape[0]], [0, mask_data.shape[1]], [0, mask_data.shape[2]]]

    # Perform the cropping/padding for all the files:
    for fn_in, fn_out in (tom_fn, mask_fn, seeds_fn, ends_fn, beginning_fn, ending_fn):
        data = np.asanyarray(nib.load(fn_in).dataobj)

        # Adjust the data to have 4 dims, since the cropping will iterate over this dim
        if data.ndim == 3:
            data = data[..., None]

        # Crop, then pad/scale to a cube (presumably 144^3, the helper's default
        # target size -- confirm in data_utils)
        data, _, _, _ = data_utils.crop_to_nonzero(np.nan_to_num(data), bbox=bbox)
        data, _ = data_utils.pad_and_scale_img_to_square_img(data)

        # Save the new file
        nib.save(nib.Nifti1Image(data, affine), fn_out)
Example #3
0
def run_tractseg(data,
                 output_type="tract_segmentation",
                 single_orientation=False,
                 dropout_sampling=False,
                 threshold=0.5,
                 bundle_specific_postprocessing=True,
                 get_probs=False,
                 peak_threshold=0.1,
                 postprocess=False,
                 peak_regression_part="All",
                 input_type="peaks",
                 blob_size_thr=50,
                 nr_cpus=-1,
                 verbose=False,
                 manual_exp_name=None,
                 inference_batch_size=1,
                 tract_definition="TractQuerier+",
                 bedpostX_input=False,
                 tract_segmentations_path=None,
                 TOM_dilation=1,
                 unit_test=False):
    """
    Run TractSeg

    Args:
        data: input peaks (4D numpy array with shape [x,y,z,9])
        output_type: TractSeg can segment not only bundles, but also the end regions of bundles.
            Moreover it can create Tract Orientation Maps (TOM).
            'tract_segmentation' [DEFAULT]: Segmentation of bundles (72 bundles).
            'endings_segmentation': Segmentation of bundle end regions (72 bundles).
            'TOM': Tract Orientation Maps (bundles selected by peak_regression_part).
                NOTE(review): previously documented as 20 bundles, which conflicts with
                peak_regression_part='All' returning 72 bundles -- confirm.
        single_orientation: Do not run model 3 times along x/y/z orientation with subsequent mean fusion.
        dropout_sampling: Create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
        threshold: Threshold for converting probability map to binary map
        bundle_specific_postprocessing: Set threshold to lower and use hole closing for CA and FX if incomplete
        get_probs: Output raw probability map instead of binary map
        peak_threshold: Not read by this function (kept for API compatibility).
            Intended meaning: all peaks shorter than peak_threshold are set to zero.
        postprocess: Simple postprocessing of segmentations: Remove small blobs and fill holes
        peak_regression_part: Only relevant for output type 'TOM'. If set to 'All' (default) it will return all
            72 bundles. If set to 'Part1'-'Part4' it will only run for a subset of the bundles to reduce memory
            load.
        input_type: Always set to "peaks"
        blob_size_thr: If setting postprocess to True, all blobs having a smaller number of voxels than specified in
            this threshold will be removed.
        nr_cpus: Number of CPUs to use. -1 means all available CPUs.
        verbose: Show debugging infos
        manual_exp_name: Name of experiment if do not want to use pretrained model but your own one
        inference_batch_size: batch size (higher: a bit faster but needs more RAM)
        tract_definition: Select which tract definitions to use. 'TractQuerier+' defines tracts mainly by their
            cortical start and end region. 'xtract' defines tracts mainly by ROIs in white matter.
        bedpostX_input: Not read by this function (kept for API compatibility).
            Intended meaning: input peaks are generated by bedpostX.
        tract_segmentations_path: path to the bundle_segmentations (only needed for peak regression to remove peaks
            outside of the segmentation mask)
        TOM_dilation: Dilation applied to the tract segmentations before using them to mask the TOMs.
        unit_test: Forwarded to trainer.predict_img in single-orientation mode for
            tract/endings segmentation and dm regression.

    Returns:
        4D numpy array with the output of tractseg
        for tract_segmentation:     [x, y, z, nr_of_bundles]
        for endings_segmentation:   [x, y, z, 2*nr_of_bundles]
        for TOM:                    [x, y, z, 3*nr_of_bundles]

    Raises:
        ValueError: If tract_definition is 'xtract' and the output type has no
            pretrained xtract model, or if the configured experiment type is
            not supported by this function.
    """
    start_time = time.time()

    if manual_exp_name is None:
        config = get_config_name(input_type,
                                 output_type,
                                 dropout_sampling=dropout_sampling,
                                 tract_definition=tract_definition)
        Config = getattr(
            importlib.import_module("tractseg.experiments.pretrained_models." +
                                    config), "Config")()
    else:
        Config = exp_utils.load_config_from_txt(
            join(C.EXP_PATH,
                 exp_utils.get_manual_exp_name_peaks(manual_exp_name, "Part1"),
                 "Hyperparameters.txt"))

    # Do not do any postprocessing if returning probabilities (because postprocessing only works on binary)
    if get_probs:
        bundle_specific_postprocessing = False
        postprocess = False

    Config = exp_utils.get_correct_labels_type(Config)
    Config.VERBOSE = verbose
    Config.TRAIN = False
    Config.TEST = False
    Config.SEGMENT = False
    Config.GET_PROBS = get_probs
    Config.LOAD_WEIGHTS = True
    Config.DROPOUT_SAMPLING = dropout_sampling
    Config.THRESHOLD = threshold
    Config.NR_CPUS = nr_cpus
    Config.INPUT_DIM = dataset_specific_utils.get_correct_input_dim(Config)
    Config.RESET_LAST_LAYER = False

    # Bundle-specific postprocessing works on probabilities (it uses a lower
    # threshold for some bundles, see docstring), so request them from the model
    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing:
        Config.GET_PROBS = True

    if manual_exp_name is not None and Config.EXPERIMENT_TYPE != "peak_regression":
        Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
            join(C.EXP_PATH, manual_exp_name), True)
    else:
        if tract_definition == "TractQuerier+":
            if input_type == "peaks":
                # NOTE(review): with dropout sampling the same weights file is used
                # as without (the original had two identical branches) -- confirm.
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_tract_segmentation_v3.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v4.npz")
                elif Config.EXPERIMENT_TYPE == "dm_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_dm_regression_v2.npz")
            else:  # T1
                if Config.EXPERIMENT_TYPE == "tract_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.NETWORK_DRIVE,
                        "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                        "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
                elif Config.EXPERIMENT_TYPE == "endings_segmentation":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_endings_segmentation_v1.npz")
                elif Config.EXPERIMENT_TYPE == "peak_regression":
                    Config.WEIGHTS_PATH = join(
                        C.WEIGHTS_DIR,
                        "pretrained_weights_peak_regression_v1.npz")
        else:  # xtract
            if Config.EXPERIMENT_TYPE == "tract_segmentation":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_tract_segmentation_xtract_v1.npz")
            elif Config.EXPERIMENT_TYPE == "dm_regression":
                Config.WEIGHTS_PATH = join(
                    C.WEIGHTS_DIR,
                    "pretrained_weights_dm_regression_xtract_v1.npz")
            else:
                raise ValueError(
                    "bundle_definition xtract not supported in combination with this output type"
                )

    if Config.VERBOSE:
        print("Hyperparameters:")
        exp_utils.print_Configs(Config)

    data = np.nan_to_num(data)

    # runtime on HCP data: 0.9s
    data, _, bbox, original_shape = data_utils.crop_to_nonzero(data)
    # runtime on HCP data: 0.5s
    data, transformation = data_utils.pad_and_scale_img_to_square_img(
        data, target_size=Config.INPUT_DIM[0], nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE in ("tract_segmentation", "endings_segmentation",
                                  "dm_regression"):
        print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
        Config.NR_OF_CLASSES = len(
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
        utils.download_pretrained_weights(
            experiment_type=Config.EXPERIMENT_TYPE,
            dropout_sampling=Config.DROPOUT_SAMPLING,
            tract_definition=tract_definition)
        model = BaseModel(Config, inference=True)

        # Keep raw probabilities for uncertainty maps, regression outputs and
        # whenever probabilities were requested (incl. bundle-specific postprocessing)
        use_probs = bool(Config.DROPOUT_SAMPLING or
                         Config.EXPERIMENT_TYPE == "dm_regression" or
                         Config.GET_PROBS)

        if single_orientation:  # mainly needed for testing because of less RAM requirements
            data_loader_inference = DataLoaderInference(Config, data=data)
            seg, _ = trainer.predict_img(Config,
                                         model,
                                         data_loader_inference,
                                         probs=use_probs,
                                         scale_to_world_shape=False,
                                         only_prediction=True,
                                         batch_size=inference_batch_size,
                                         unit_test=unit_test)
        else:
            seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                Config,
                model,
                data=data,
                scale_to_world_shape=False,
                only_prediction=True,
                batch_size=inference_batch_size)
            seg = direction_merger.mean_fusion(Config.THRESHOLD,
                                               seg_xyz,
                                               probs=use_probs)

    elif Config.EXPERIMENT_TYPE == "peak_regression":
        weights = {
            "Part1": "pretrained_weights_peak_regression_part1_v2.npz",
            "Part2": "pretrained_weights_peak_regression_part2_v2.npz",
            "Part3": "pretrained_weights_peak_regression_part3_v2.npz",
            "Part4": "pretrained_weights_peak_regression_part4_v2.npz",
        }
        if peak_regression_part == "All":
            parts = ["Part1", "Part2", "Part3", "Part4"]
            # NOTE(review): sized from the loaded Config's NR_OF_CLASSES; assumes it
            # holds the bundle count of the full "All" set at this point -- confirm.
            seg_all = np.zeros((data.shape[0], data.shape[1], data.shape[2],
                                Config.NR_OF_CLASSES * 3))
        else:
            parts = [peak_regression_part]
            Config.CLASSES = "All_" + peak_regression_part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

        for idx, part in enumerate(parts):
            if manual_exp_name is not None:
                manual_exp_name_peaks = exp_utils.get_manual_exp_name_peaks(
                    manual_exp_name, part)
                Config.WEIGHTS_PATH = exp_utils.get_best_weights_path(
                    join(C.EXP_PATH, manual_exp_name_peaks), True)
            else:
                # Use C.WEIGHTS_DIR like every other pretrained-weights path in this function
                Config.WEIGHTS_PATH = join(C.WEIGHTS_DIR, weights[part])
            print("Loading weights from: {}".format(Config.WEIGHTS_PATH))
            Config.CLASSES = "All_" + part
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            utils.download_pretrained_weights(
                experiment_type=Config.EXPERIMENT_TYPE,
                dropout_sampling=Config.DROPOUT_SAMPLING,
                part=part,
                tract_definition=tract_definition)
            model = BaseModel(Config, inference=True)

            if single_orientation:
                data_loader_inference = DataLoaderInference(Config, data=data)
                seg, _ = trainer.predict_img(Config,
                                             model,
                                             data_loader_inference,
                                             probs=True,
                                             scale_to_world_shape=False,
                                             only_prediction=True,
                                             batch_size=inference_batch_size)
            else:
                # 3 dir for Peaks -> bad results
                seg_xyz, _ = direction_merger.get_seg_single_img_3_directions(
                    Config,
                    model,
                    data=data,
                    scale_to_world_shape=False,
                    only_prediction=True,
                    batch_size=inference_batch_size)
                seg = direction_merger.mean_fusion_peaks(seg_xyz,
                                                         nr_cpus=nr_cpus)

            if peak_regression_part == "All":
                # Each part fills its own consecutive block of channels
                start = idx * Config.NR_OF_CLASSES
                seg_all[:, :, :, start:start + Config.NR_OF_CLASSES] = seg

        if peak_regression_part == "All":
            Config.CLASSES = "All"
            Config.NR_OF_CLASSES = 3 * len(
                dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])
            seg = seg_all

    else:
        # Previously this fell through and failed later with an unbound `seg`
        raise ValueError(
            "Unsupported experiment type: {}".format(Config.EXPERIMENT_TYPE))

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and bundle_specific_postprocessing and not dropout_sampling:
        # Runtime ~4s
        seg = img_utils.bundle_specific_postprocessing(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:])

    # runtime on HCP data: 5.1s
    seg = data_utils.cut_and_scale_img_back_to_original_img(seg,
                                                            transformation,
                                                            nr_cpus=nr_cpus)
    # runtime on HCP data: 1.6s
    seg = data_utils.add_original_zero_padding_again(seg, bbox, original_shape,
                                                     Config.NR_OF_CLASSES)

    if Config.EXPERIMENT_TYPE == "peak_regression":
        seg = peak_utils.mask_and_normalize_peaks(
            seg,
            tract_segmentations_path,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            TOM_dilation,
            nr_cpus=nr_cpus)

    if Config.EXPERIMENT_TYPE == "tract_segmentation" and postprocess and not dropout_sampling:
        # Runtime ~7s for 1.25mm resolution
        # Runtime ~1.5s for  2mm resolution
        seg = img_utils.postprocess_segmentations(
            seg,
            dataset_specific_utils.get_bundle_names(Config.CLASSES)[1:],
            blob_thr=blob_size_thr,
            hole_closing=None)

    exp_utils.print_verbose(
        Config.VERBOSE, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg