示例#1
0
def main():
    """Apply a linear transformation to a tractogram and save the result.

    The transformation matrix is read from a text file and, when requested
    on the command line, inverted before being applied. The moved
    streamlines are attached to the target file as their new spatial
    reference. Depending on the command-line flags, invalid streamlines are
    removed before saving, kept (saving without bounding-box validation),
    or left for the default validation of the save routine.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    required_inputs = [args.moving_tractogram,
                       args.target_file,
                       args.transformation]
    assert_inputs_exist(parser, required_inputs)
    assert_outputs_exist(parser, args, args.out_tractogram)

    # The moving tractogram is loaded without bounding-box validation.
    moving_sft = load_tractogram_with_reference(
        parser, args, args.moving_tractogram, bbox_check=False)

    matrix = np.loadtxt(args.transformation)
    transfo = np.linalg.inv(matrix) if args.inverse else matrix

    new_sft = StatefulTractogram(
        transform_streamlines(moving_sft.streamlines, transfo),
        args.target_file,
        Space.RASMM,
        data_per_point=moving_sft.data_per_point,
        data_per_streamline=moving_sft.data_per_streamline)

    if args.remove_invalid:
        count_before = len(new_sft)
        new_sft.remove_invalid_streamlines()
        nb_removed = count_before - len(new_sft)
        logging.warning('Removed {} invalid streamlines.'.format(nb_removed))
    elif args.keep_invalid:
        if not new_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        # Explicitly skip validation so invalid streamlines can be written.
        save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
        return

    save_tractogram(new_sft, args.out_tractogram)
def main():
    """Warp a tractogram with a deformation field and save the result.

    The deformation image is loaded and its singleton dimensions squeezed
    out. The script aborts through ``parser.error`` when the tractogram and
    the deformation field do not share compatible headers. The warped
    streamlines are attached to the target file as their new spatial
    reference. Depending on the command-line flags, invalid streamlines are
    removed before saving, kept (saving without bounding-box validation),
    or left for the default validation of the save routine.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    required_inputs = [args.moving_tractogram,
                       args.target_file,
                       args.deformation]
    assert_inputs_exist(parser, required_inputs)
    assert_outputs_exist(parser, args, args.out_tractogram)

    # The moving tractogram is loaded without bounding-box validation.
    sft = load_tractogram_with_reference(
        parser, args, args.moving_tractogram, bbox_check=False)

    deformation = nib.load(args.deformation)
    deformation_data = np.squeeze(deformation.get_fdata())

    headers_match = is_header_compatible(sft, deformation)
    if not headers_match:
        parser.error('Input tractogram/reference do not have the same spatial '
                     'attribute as the deformation field.')

    # NOTE(review): per the original author's note the warp is applied
    # in-place, i.e. `sft` may be modified by this call -- confirm against
    # warp_streamlines.
    moved_streamlines = warp_streamlines(sft, deformation_data)
    new_sft = StatefulTractogram(
        moved_streamlines,
        args.target_file,
        Space.RASMM,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)

    if args.remove_invalid:
        count_before = len(new_sft)
        new_sft.remove_invalid_streamlines()
        nb_removed = count_before - len(new_sft)
        logging.warning('Removed {} invalid streamlines.'.format(nb_removed))
    elif args.keep_invalid:
        if not new_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        # Explicitly skip validation so invalid streamlines can be written.
        save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
        return

    save_tractogram(new_sft, args.out_tractogram)
示例#3
0
def direct_streamline_norm(
    streams,
    fa_path,
    ap_path,
    dir_path,
    track_type,
    target_samples,
    conn_model,
    network,
    node_size,
    dens_thresh,
    ID,
    roi,
    min_span_tree,
    disp_filt,
    parc,
    prune,
    atlas,
    labels_im_file,
    uatlas,
    labels,
    coords,
    norm,
    binary,
    atlas_mni,
    basedir_path,
    curv_thr_list,
    step_list,
    directget,
    min_length,
    error_margin,
    t1_aligned_mni
):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    Whether normalization actually runs is controlled by the
    ``tracking -> DSN`` entry of pynets' ``runconfig.yaml``. When it is
    disabled, the inputs are passed through unchanged (``streams``,
    ``fa_path`` and ``labels_im_file`` are returned in place of the
    normalized streamlines, warped FA and MNI-space atlas respectively).

    Parameters
    ----------
    streams : str
        File path to the native-space streamline array sequence in .trk
        format that is to be normalized.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Tolerance value; here it is only embedded in the output file names
        (``_tol-<error_margin>``) and returned unchanged.
    t1_aligned_mni : str
        File path to the T1w Nifti1Image in template MNI space.

    Returns
    -------
    streams_mni : str
        File path to normalized streamline array sequence in .trk format
        (or `streams` unchanged when DSN is disabled).
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to the '_liberal' union parcellation Nifti1Image written
        under ``<dir_path>/parcellations`` (or `labels_im_file` unchanged
        when DSN is disabled).
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image (or `fa_path` unchanged
        when DSN is disabled).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Tolerance value, returned unchanged.

    Raises
    ------
    ValueError
        If the input tractogram contains no streamlines, or if none of the
        tested voxel-grid configurations retains at least 90% of the
        streamlines after warping.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and structural
      brain networks. Network Neuroscience, 1-19.
    """
    import sys
    import gc
    from dipy.tracking.streamline import transform_streamlines
    from pynets.registration import reg_utils as regutils
    # from pynets.plotting import plot_gen
    import pkg_resources
    import yaml
    import os.path as op
    from pynets.registration.reg_utils import vdc
    from nilearn.image import resample_to_img
    from dipy.io.streamline import load_tractogram
    from dipy.tracking import utils
    from dipy.tracking._utils import _mapping_to_voxel
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram

    # from pynets.core.utils import missing_elements

    # The open() call lives inside the try so that a missing runconfig is
    # actually reported by the handler below. safe_load is used because
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and raises TypeError from PyYAML 6 onwards.
    try:
        with open(
            pkg_resources.resource_filename("pynets", "runconfig.yaml"), "r"
        ) as stream:
            hardcoded_params = yaml.safe_load(stream)
    except (FileNotFoundError, yaml.YAMLError) as e:
        print(e, "Failed to parse runconfig.yaml")
        sys.exit(1)

    run_dsn = hardcoded_params['tracking']["DSN"][0]

    if run_dsn is True:
        # makedirs (rather than mkdir) also creates missing parent
        # directories and tolerates the directory already existing.
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        os.makedirs(dsn_dir, exist_ok=True)

        namer_dir = f"{dir_path}/tractography"
        os.makedirs(namer_dir, exist_ok=True)

        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)
        vox_size = fa_img.header.get_zooms()[0]
        template_path = pkg_resources.resource_filename(
            "pynets", f"templates/FA_{int(vox_size)}mm.nii.gz"
        )

        if sys.platform.startswith('win') is False:
            try:
                template_img = nib.load(template_path)
            except indexed_gzip.ZranError as e:
                print(e,
                      f"\nCannot load FA template. Do you have git-lfs "
                      f"installed?")
                sys.exit(1)
        else:
            try:
                template_img = nib.load(template_path)
            except ImportError as e:
                print(e, f"\nCannot load FA template. Do you have git-lfs "
                      f"installed?")
                sys.exit(1)

        uatlas_mni_img = nib.load(atlas_mni)
        t1_aligned_mni_img = nib.load(t1_aligned_mni)
        brain_mask = np.asarray(t1_aligned_mni_img.dataobj).astype("bool")

        # Both output files share the same descriptive stem; only the
        # prefix and the extension differ.
        network_tag = network + "_" if network is not None else ""
        roi_tag = (op.basename(roi).split(".")[0] + "_" if roi is not None
                   else "")
        if (node_size != "parc") and (node_size is not None):
            node_tag = "%s%s" % ("_" + str(node_size), "mm_")
        else:
            node_tag = "_"

        name_stem = (
            "%s%s%s_%s%scurv%sstep%stracktype-%s_directget-%s"
            "_minlength-%s_tol-%s"
        ) % (
            network_tag,
            roi_tag,
            conn_model,
            target_samples,
            node_tag,
            str(curv_thr_list).replace(", ", "_"),
            str(step_list).replace(", ", "_"),
            track_type,
            directget,
            min_length,
            error_margin,
        )
        streams_mni = f"{namer_dir}/streamlines_mni_{name_stem}.trk"
        density_mni = f"{namer_dir}/density_map_mni_{name_stem}.nii.gz"

        # streams_warp_png = '/tmp/dsn.png'

        # SyN FA->Template
        [mapping, affine_map, warped_fa] = regutils.wm_syn(
            template_path, fa_path, t1_aligned_mni, ap_path, dsn_dir
        )

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        streams_in_curr_grid = transform_streamlines(
            streamlines, warped_fa_affine)

        # The retention ratio below divides by the number of input
        # streamlines, so fail with a clear message on an empty tractogram.
        if len(streams_in_curr_grid) == 0:
            raise ValueError(f"DSN failed. No streamlines found in "
                             f"{streams}.")

        # Create isocenter mapping where we anchor the origin transformation
        # affine to the corner of the FOV by scaling x, y, z offsets according
        # to a multiplicative van der Corput sequence with a base value equal
        # to the voxel resolution
        [x_mul, y_mul, z_mul] = [vdc(i, vox_size) for i in range(1, 4)]

        ref_grid_aff = vox_size * np.eye(4)
        ref_grid_aff[3][3] = 1

        streams_final_filt = []
        i = 0
        # Test for various types of voxel-grid configurations
        combs = [(-x_mul, -y_mul, -z_mul), (-x_mul, -y_mul, z_mul),
                 (-x_mul, y_mul, -z_mul), (x_mul, -y_mul, -z_mul),
                 (x_mul, y_mul, z_mul)]
        # Keep trying configurations until at least 90% of the streamlines
        # survive the warp.
        while len(streams_final_filt)/len(streams_in_curr_grid) < 0.90:
            print(f"Warping streamlines to MNI space. Attempt {i}...")
            print(len(streams_final_filt)/len(streams_in_curr_grid))
            adjusted_affine = affine_map.affine.copy()
            if i > len(combs) - 1:
                raise ValueError('DSN failed. Header orientation '
                                 'information may be corrupted. '
                                 'Is your dataset oblique?')

            adjusted_affine[0][3] = adjusted_affine[0][3] * combs[i][0]
            adjusted_affine[1][3] = adjusted_affine[1][3] * combs[i][1]
            adjusted_affine[2][3] = adjusted_affine[2][3] * combs[i][2]

            streams_final_filt = regutils.warp_streamlines(adjusted_affine,
                                                           ref_grid_aff,
                                                           mapping,
                                                           warped_fa_img,
                                                           streams_in_curr_grid,
                                                           brain_mask)

            i += 1

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=uatlas_mni_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_mni, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        # plot_gen.show_template_bundles(streams_final_filt_final, atlas_mni,
        # streams_warp_png) plot_gen.show_template_bundles(streamlines,
        # fa_path, streams_warp_png)

        # Create and save MNI density map
        nib.save(
            nib.Nifti1Image(
                utils.density_map(
                    streams_final_filt_final,
                    affine=np.eye(4),
                    vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_mni,
        )

        # Map parcellation from native space back to MNI-space and create an
        # 'uncertainty-union' parcellation with original mni-space uatlas
        warped_uatlas = affine_map.transform_inverse(
            mapping.transform(
                np.asarray(atlas_img.dataobj).astype("int"),
                interpolation="nearestneighbour",
            ),
            interp="nearest",
        )
        atlas_img.uncache()
        warped_uatlas_img_res_data = np.asarray(
            resample_to_img(
                nib.Nifti1Image(warped_uatlas, affine=warped_fa_affine),
                uatlas_mni_img,
                interpolation="nearest",
                clip=False,
            ).dataobj
        )
        uatlas_mni_data = np.asarray(uatlas_mni_img.dataobj)
        uatlas_mni_img.uncache()
        # True wherever the two parcellations do NOT both have a label.
        overlap_mask = np.invert(
            warped_uatlas_img_res_data.astype("bool") *
            uatlas_mni_data.astype("bool"))
        os.makedirs(f"{dir_path}/parcellations", exist_ok=True)
        atlas_mni = f"{dir_path}/parcellations/" \
                    f"{op.basename(uatlas).split('.nii')[0]}_liberal.nii.gz"

        nib.save(
            nib.Nifti1Image(
                warped_uatlas_img_res_data * overlap_mask.astype("int")
                + uatlas_mni_data * overlap_mask.astype("int")
                + np.invert(overlap_mask).astype("int")
                * warped_uatlas_img_res_data.astype("int"),
                affine=uatlas_mni_img.affine,
            ),
            atlas_mni,
        )

        # Release the large intermediate arrays before returning.
        del (
            tractogram,
            streamlines,
            warped_uatlas_img_res_data,
            uatlas_mni_data,
            overlap_mask,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        assert len(coords) == len(labels)

    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_mni = streams
        warped_fa = fa_path
        atlas_mni = labels_im_file

    return (
        streams_mni,
        dir_path,
        track_type,
        target_samples,
        conn_model,
        network,
        node_size,
        dens_thresh,
        ID,
        roi,
        min_span_tree,
        disp_filt,
        parc,
        prune,
        atlas,
        uatlas,
        labels,
        coords,
        norm,
        binary,
        atlas_mni,
        directget,
        warped_fa,
        min_length,
        error_margin
    )
示例#4
0
def transform_warp_streamlines(sft,
                               linear_transfo,
                               target,
                               inverse=False,
                               deformation_data=None,
                               remove_invalid=True,
                               cut_invalid=False):
    # TODO rename transform_warp_sft
    """ Transform a tractogram with an affine, then optionally apply a warp
    from antsRegistration.
    Invalid streamlines are cut or removed to preserve sft validity.

    Note that the input sft is switched to RASMM space / center origin
    in-place before its streamlines are transformed.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel.
        When None, only the linear transformation is applied.
    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box.
        Ignored when cut_invalid is set.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them. Keep the longest
        segment only.

    Returns
    -------
    new_sft : StatefulTractogram
        Transformed tractogram in RASMM space, using target as reference.
    """
    sft.to_rasmm()
    sft.to_center()

    matrix = np.linalg.inv(linear_transfo) if inverse else linear_transfo
    streamlines = transform_streamlines(sft.streamlines, matrix)

    if deformation_data is not None:
        affine, _, _, _ = get_reference_info(target)

        # Because of duplication, the points are processed chunk by chunk,
        # which is necessary for a big dataset (especially if not compressed)
        streamlines = ArraySequence(streamlines)
        total_points = len(streamlines._data)
        chunk_size = 1000000
        world_to_vox = np.linalg.inv(affine)

        for start in range(0, total_points, chunk_size):
            stop = min(start + chunk_size, total_points)
            chunk = streamlines._data[start:stop]

            # The deformation field is sampled in VOX space.
            # No corner shift is needed since values are interpolated.
            vox_coords = np.array(
                transform_streamlines(chunk, world_to_vox)).T.tolist()

            x_def, y_def, z_def = (
                map_coordinates(deformation_data[..., axis],
                                vox_coords,
                                order=1)
                for axis in range(3))

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
            displaced = np.array([-1 * x_def, -1 * y_def, z_def])
            displaced += np.array(chunk).T

            streamlines._data[start:stop] = displaced.T

    new_sft = StatefulTractogram(
        streamlines,
        target,
        Space.RASMM,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline)

    # Cutting takes precedence over removal.
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    return new_sft
def empty_remove_invalid():
    """Check that removing invalid streamlines from an empty tractogram
    leaves its streamline data empty."""
    empty_sft = StatefulTractogram([], filepath_dix['gs.nii'], Space.VOX)
    empty_sft.remove_invalid_streamlines()

    remaining_data = empty_sft.streamlines.data
    assert_array_equal([], remaining_data)
示例#6
0
    def _run_interface(self, runtime):
        import gc
        import os
        import time
        import os.path as op
        from dipy.io import load_pickle
        from colorama import Fore, Style
        from dipy.data import get_sphere
        from pynets.core import utils
        from pynets.core.utils import load_runconfig
        from pynets.dmri.estimation import reconstruction
        from pynets.dmri.track import (
            create_density_map,
            track_ensemble,
        )
        from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
            Origin
        from dipy.io.streamline import save_tractogram
        from nipype.utils.filemanip import copyfile, fname_presuffix

        hardcoded_params = load_runconfig()
        use_life = hardcoded_params['tracking']["use_life"][0]
        roi_neighborhood_tol = hardcoded_params['tracking'][
            "roi_neighborhood_tol"][0]
        sphere = hardcoded_params['tracking']["sphere"][0]
        target_samples = hardcoded_params['tracking']["tracking_samples"][0]

        dir_path = utils.do_dir_path(self.inputs.atlas,
                                     os.path.dirname(self.inputs.dwi_file))

        namer_dir = "{}/tractography".format(dir_path)
        if not os.path.isdir(namer_dir):
            os.makedirs(namer_dir, exist_ok=True)

        # Load diffusion data
        dwi_file_tmp_path = fname_presuffix(self.inputs.dwi_file,
                                            suffix="_tmp",
                                            newpath=runtime.cwd)
        copyfile(self.inputs.dwi_file,
                 dwi_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
        dwi_data = dwi_img.get_fdata(dtype=np.float32)

        # Load FA data
        fa_file_tmp_path = fname_presuffix(self.inputs.fa_path,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.fa_path,
                 fa_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        fa_img = nib.load(fa_file_tmp_path, mmap=True)

        labels_im_file_tmp_path = fname_presuffix(self.inputs.labels_im_file,
                                                  suffix="_tmp",
                                                  newpath=runtime.cwd)
        copyfile(self.inputs.labels_im_file,
                 labels_im_file_tmp_path,
                 copy=True,
                 use_hardlink=False)

        # Load B0 mask
        B0_mask_tmp_path = fname_presuffix(self.inputs.B0_mask,
                                           suffix="_tmp",
                                           newpath=runtime.cwd)
        copyfile(self.inputs.B0_mask,
                 B0_mask_tmp_path,
                 copy=True,
                 use_hardlink=False)

        streams = "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s" % (
            runtime.cwd,
            "/streamlines_",
            "%s" % (self.inputs.subnet +
                    "_" if self.inputs.subnet is not None else ""),
            "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                    "_" if self.inputs.roi is not None else ""),
            self.inputs.conn_model,
            "_",
            target_samples,
            "_",
            "%s" % ("%s%s" % (self.inputs.node_radius, "mm_") if
                    ((self.inputs.node_radius != "parc") and
                     (self.inputs.node_radius is not None)) else "parc_"),
            "curv-",
            str(self.inputs.curv_thr_list).replace(", ", "_"),
            "_step-",
            str(self.inputs.step_list).replace(", ", "_"),
            "_traversal-",
            self.inputs.traversal,
            "_minlength-",
            self.inputs.min_length,
            ".trk",
        )

        if os.path.isfile(f"{namer_dir}/{op.basename(streams)}"):
            from dipy.io.streamline import load_tractogram
            copyfile(
                f"{namer_dir}/{op.basename(streams)}",
                streams,
                copy=True,
                use_hardlink=False,
            )
            tractogram = load_tractogram(
                streams,
                fa_img,
                bbox_valid_check=False,
            )

            streamlines = tractogram.streamlines

            # Create streamline density map
            try:
                [dir_path, dm_path] = create_density_map(
                    fa_img,
                    dir_path,
                    streamlines,
                    self.inputs.conn_model,
                    self.inputs.node_radius,
                    self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.subnet,
                    self.inputs.roi,
                    self.inputs.traversal,
                    self.inputs.min_length,
                    namer_dir,
                )
            except BaseException:
                print('Density map failed. Check tractography output.')
                dm_path = None

            del streamlines, tractogram
            fa_img.uncache()
            dwi_img.uncache()
            gc.collect()
            self._results["dm_path"] = dm_path
            self._results["streams"] = streams
            recon_path = None
        else:
            # Fit diffusion model
            # Save reconstruction to .npy
            recon_path = "%s%s%s%s%s%s%s%s" % (
                runtime.cwd,
                "/reconstruction_",
                "%s" % (self.inputs.subnet +
                        "_" if self.inputs.subnet is not None else ""),
                "%s" % (op.basename(self.inputs.roi).split(".")[0] +
                        "_" if self.inputs.roi is not None else ""),
                self.inputs.conn_model,
                "_",
                "%s" % ("%s%s" % (self.inputs.node_radius, "mm") if
                        ((self.inputs.node_radius != "parc") and
                         (self.inputs.node_radius is not None)) else "parc"),
                ".hdf5",
            )

            gtab_file_tmp_path = fname_presuffix(self.inputs.gtab_file,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gtab_file,
                     gtab_file_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gtab = load_pickle(gtab_file_tmp_path)

            # Only re-run the reconstruction if we have to
            if not os.path.isfile(f"{namer_dir}/{op.basename(recon_path)}"):
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')
                hf.close()

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(2)
                del model
            elif os.path.getsize(f"{namer_dir}/{op.basename(recon_path)}") > 0:
                print(f"Found existing reconstruction with "
                      f"{self.inputs.conn_model}. Loading...")
                copyfile(
                    f"{namer_dir}/{op.basename(recon_path)}",
                    recon_path,
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
            else:
                import h5py
                model = reconstruction(
                    self.inputs.conn_model,
                    gtab,
                    dwi_data,
                    B0_mask_tmp_path,
                )[0]
                with h5py.File(recon_path, 'w') as hf:
                    hf.create_dataset("reconstruction",
                                      data=model.astype('float32'),
                                      dtype='f4')
                hf.close()

                copyfile(
                    recon_path,
                    f"{namer_dir}/{op.basename(recon_path)}",
                    copy=True,
                    use_hardlink=False,
                )
                time.sleep(5)
                del model
            dwi_img.uncache()
            del dwi_data

            # Load atlas wm-gm interface reduced version for seeding
            labels_im_file_tmp_path_wm_gm_int = fname_presuffix(
                self.inputs.labels_im_file_wm_gm_int,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.labels_im_file_wm_gm_int,
                     labels_im_file_tmp_path_wm_gm_int,
                     copy=True,
                     use_hardlink=False)

            t1w2dwi_tmp_path = fname_presuffix(self.inputs.t1w2dwi,
                                               suffix="_tmp",
                                               newpath=runtime.cwd)
            copyfile(self.inputs.t1w2dwi,
                     t1w2dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            gm_in_dwi_tmp_path = fname_presuffix(self.inputs.gm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.gm_in_dwi,
                     gm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            vent_csf_in_dwi_tmp_path = fname_presuffix(
                self.inputs.vent_csf_in_dwi,
                suffix="_tmp",
                newpath=runtime.cwd)
            copyfile(self.inputs.vent_csf_in_dwi,
                     vent_csf_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            wm_in_dwi_tmp_path = fname_presuffix(self.inputs.wm_in_dwi,
                                                 suffix="_tmp",
                                                 newpath=runtime.cwd)
            copyfile(self.inputs.wm_in_dwi,
                     wm_in_dwi_tmp_path,
                     copy=True,
                     use_hardlink=False)

            if self.inputs.waymask:
                waymask_tmp_path = fname_presuffix(self.inputs.waymask,
                                                   suffix="_tmp",
                                                   newpath=runtime.cwd)
                copyfile(self.inputs.waymask,
                         waymask_tmp_path,
                         copy=True,
                         use_hardlink=False)
            else:
                waymask_tmp_path = None

            # Iteratively build a list of streamlines for each ROI while
            # tracking
            print(f"{Fore.GREEN}Target streamlines per iteration: "
                  f"{Fore.BLUE} "
                  f"{target_samples}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
                  f"{self.inputs.curv_thr_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
                  f"{self.inputs.step_list}")
            print(Style.RESET_ALL)
            print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
                  f"{self.inputs.track_type}")
            print(Style.RESET_ALL)
            if self.inputs.traversal == "prob":
                print(f"{Fore.GREEN}Direction-getting type: {Fore.BLUE}"
                      f"Probabilistic")
            elif self.inputs.traversal == "clos":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Closest Peak")
            elif self.inputs.traversal == "det":
                print(f"{Fore.GREEN}Direction-getting type: "
                      f"{Fore.BLUE}Deterministic Maximum")
            else:
                raise ValueError("Direction-getting type not recognized!")

            print(Style.RESET_ALL)

            # Commence Ensemble Tractography
            try:
                streamlines = track_ensemble(
                    target_samples, labels_im_file_tmp_path_wm_gm_int,
                    labels_im_file_tmp_path, recon_path, get_sphere(sphere),
                    self.inputs.traversal, self.inputs.curv_thr_list,
                    self.inputs.step_list,
                    self.inputs.track_type, self.inputs.maxcrossing,
                    int(roi_neighborhood_tol), self.inputs.min_length,
                    waymask_tmp_path, B0_mask_tmp_path, t1w2dwi_tmp_path,
                    gm_in_dwi_tmp_path, vent_csf_in_dwi_tmp_path,
                    wm_in_dwi_tmp_path, self.inputs.tiss_class)
                gc.collect()
            except BaseException as w:
                print(f"\n{Fore.RED}Tractography failed: {w}")
                print(Style.RESET_ALL)
                streamlines = None

            if streamlines is not None:
                # import multiprocessing
                # from pynets.core.utils import kill_process_family
                # return kill_process_family(int(
                # multiprocessing.current_process().pid))

                # Linear Fascicle Evaluation (LiFE)
                if use_life is True:
                    print('Using LiFE to evaluate streamline plausibility...')
                    from pynets.dmri.utils import \
                        evaluate_streamline_plausibility
                    dwi_img = nib.load(dwi_file_tmp_path)
                    dwi_data = dwi_img.get_fdata(dtype=np.float32)
                    orig_count = len(streamlines)

                    if self.inputs.waymask:
                        mask_data = nib.load(waymask_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    else:
                        mask_data = nib.load(wm_in_dwi_tmp_path).get_fdata(
                        ).astype('bool').astype('int')
                    try:
                        streamlines = evaluate_streamline_plausibility(
                            dwi_data,
                            gtab,
                            mask_data,
                            streamlines,
                            sphere=sphere)
                    except BaseException:
                        print(f"Linear Fascicle Evaluation failed. "
                              f"Visually checking streamlines output "
                              f"{namer_dir}/{op.basename(streams)} is "
                              f"recommended.")
                    if len(streamlines) < 0.5 * orig_count:
                        raise ValueError('LiFE revealed no plausible '
                                         'streamlines in the tractogram!')
                    del dwi_data, mask_data

                # Save streamlines to trk
                stf = StatefulTractogram(streamlines,
                                         fa_img,
                                         origin=Origin.NIFTI,
                                         space=Space.VOXMM)
                stf.remove_invalid_streamlines()

                save_tractogram(
                    stf,
                    streams,
                )

                del stf

                copyfile(
                    streams,
                    f"{namer_dir}/{op.basename(streams)}",
                    copy=True,
                    use_hardlink=False,
                )

                # Create streamline density map
                try:
                    [dir_path, dm_path] = create_density_map(
                        dwi_img,
                        dir_path,
                        streamlines,
                        self.inputs.conn_model,
                        self.inputs.node_radius,
                        self.inputs.curv_thr_list,
                        self.inputs.step_list,
                        self.inputs.subnet,
                        self.inputs.roi,
                        self.inputs.traversal,
                        self.inputs.min_length,
                        namer_dir,
                    )
                except BaseException:
                    print('Density map failed. Check tractography output.')
                    dm_path = None

                del streamlines
                dwi_img.uncache()
                gc.collect()
                self._results["dm_path"] = dm_path
                self._results["streams"] = streams
            else:
                self._results["streams"] = None
                self._results["dm_path"] = None
            tmp_files = [
                gtab_file_tmp_path, wm_in_dwi_tmp_path, gm_in_dwi_tmp_path,
                vent_csf_in_dwi_tmp_path, t1w2dwi_tmp_path
            ]

            for j in tmp_files:
                if j is not None:
                    if os.path.isfile(j):
                        os.system(f"rm -f {j} &")

        self._results["track_type"] = self.inputs.track_type
        self._results["conn_model"] = self.inputs.conn_model
        self._results["dir_path"] = dir_path
        self._results["subnet"] = self.inputs.subnet
        self._results["node_radius"] = self.inputs.node_radius
        self._results["dens_thresh"] = self.inputs.dens_thresh
        self._results["ID"] = self.inputs.ID
        self._results["roi"] = self.inputs.roi
        self._results["min_span_tree"] = self.inputs.min_span_tree
        self._results["disp_filt"] = self.inputs.disp_filt
        self._results["parc"] = self.inputs.parc
        self._results["prune"] = self.inputs.prune
        self._results["atlas"] = self.inputs.atlas
        self._results["parcellation"] = self.inputs.parcellation
        self._results["labels"] = self.inputs.labels
        self._results["coords"] = self.inputs.coords
        self._results["norm"] = self.inputs.norm
        self._results["binary"] = self.inputs.binary
        self._results["atlas_t1w"] = self.inputs.atlas_t1w
        self._results["curv_thr_list"] = self.inputs.curv_thr_list
        self._results["step_list"] = self.inputs.step_list
        self._results["fa_path"] = fa_file_tmp_path
        self._results["traversal"] = self.inputs.traversal
        self._results["labels_im_file"] = labels_im_file_tmp_path
        self._results["min_length"] = self.inputs.min_length

        tmp_files = [B0_mask_tmp_path, dwi_file_tmp_path]

        for j in tmp_files:
            if j is not None:
                if os.path.isfile(j):
                    os.system(f"rm -f {j} &")

        # Exercise caution when deleting copied recon_path
        # if recon_path is not None:
        #     if os.path.isfile(recon_path):
        #         os.remove(recon_path)

        return runtime
示例#7
0
def _dsn_output_path(namer_dir, stem, extension, subnet, roi, conn_model,
                     node_radius, curv_thr_list, step_list, track_type,
                     traversal, min_length):
    """
    Build the path of a T1w-space DSN output file inside `namer_dir`.

    The tracking parameters are encoded in the file name so that outputs of
    different runs do not collide.

    Parameters
    ----------
    namer_dir : str
        Directory that will contain the file.
    stem : str
        Leading part of the file name (e.g. 'streamlines_t1w').
    extension : str
        File extension including the leading dot (e.g. '.trk').
    subnet : str or None
        Subnet name; omitted from the file name when None.
    roi : str or None
        ROI file path; only its basename up to the first '.' is used, and it
        is omitted from the file name when None.
    conn_model : str
        Connectivity reconstruction method.
    node_radius : int, str or None
        Node radius; encoded as '<radius>mm' unless it is 'parc' or None.
    curv_thr_list : list
        Curvature thresholds used for ensemble tracking.
    step_list : list
        Step sizes used for ensemble tracking.
    track_type : str
        Tracking algorithm used.
    traversal : str
        Direction-getting approach used.
    min_length : int
        Minimum fiber length threshold.

    Returns
    -------
    str
        Full path of the output file.
    """
    import os.path as op

    subnet_part = f"{subnet}_" if subnet is not None else ""
    roi_part = (f"{op.basename(roi).split('.')[0]}_"
                if roi is not None else "")
    if node_radius != "parc" and node_radius is not None:
        radius_part = f"_{node_radius}mm_"
    else:
        radius_part = "_"
    # Lists are rendered like '[40_30]' by collapsing the ', ' separators.
    curv_part = str(curv_thr_list).replace(", ", "_")
    step_part = str(step_list).replace(", ", "_")

    return (f"{namer_dir}/{stem}_{subnet_part}{roi_part}{conn_model}"
            f"{radius_part}curv{curv_part}step{step_part}"
            f"tracktype-{track_type}_traversal-{traversal}"
            f"_minlength-{min_length}{extension}")


def direct_streamline_norm(streams,
                           fa_path,
                           ap_path,
                           dir_path,
                           track_type,
                           conn_model,
                           subnet,
                           node_radius,
                           dens_thresh,
                           ID,
                           roi,
                           min_span_tree,
                           disp_filt,
                           parc,
                           prune,
                           atlas,
                           labels_im_file,
                           parcellation,
                           labels,
                           coords,
                           norm,
                           binary,
                           atlas_t1w,
                           basedir_path,
                           curv_thr_list,
                           step_list,
                           traversal,
                           min_length,
                           t1w_brain,
                           run_dsn=False):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to T1w space (Direct Streamline Normalization, DSN).

    When `run_dsn` is not True, nothing is computed: the inputs are passed
    through and the native-space streamlines, FA image and atlas are
    returned in place of their normalized counterparts.

    Parameters
    ----------
    streams : str
        File path to the streamline array sequence in .trk format.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_t1w : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    t1w_brain : str
        File path to the T1w Nifti1Image.
    run_dsn : bool
        Normalization is only performed when this is exactly True. Default
        is False.

    Returns
    -------
    streams_t1w : str
        File path to the normalized streamlines in .trk format, or `streams`
        unchanged when DSN is skipped.
    dir_path, track_type, conn_model, subnet, node_radius, dens_thresh, ID, \
roi, min_span_tree, disp_filt, parc, prune, atlas, parcellation, labels, \
coords, norm, binary
        The corresponding inputs, passed through unchanged and in this order.
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in the same
        morphological space as the streamlines (`atlas_t1w` after DSN,
        `labels_im_file` otherwise).
    traversal : str
        The input `traversal`, passed through unchanged.
    warped_fa : str
        File path to the warped image returned by the registration step, or
        `fa_path` unchanged when DSN is skipped.
    min_length : int
        The input `min_length`, passed through unchanged.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and structural
      brain networks. Network Neuroscience, 1-19.
    """
    if run_dsn is True:
        # Heavy dependencies are imported only when DSN actually runs, so
        # the pass-through path below does not pay for (or fail on) them.
        # `nib` and `np` are taken from the enclosing module.
        import gc
        import os
        from dipy.tracking.streamline import transform_streamlines
        from pynets.registration import utils as regutils
        from pynets.plotting.brain import show_template_bundles
        from dipy.io.streamline import load_tractogram
        from dipy.tracking._utils import _mapping_to_voxel
        from dipy.tracking.utils import density_map
        from dipy.io.stateful_tractogram import Space, StatefulTractogram, \
            Origin
        from dipy.io.streamline import save_tractogram

        # makedirs(exist_ok=True) creates missing parents (e.g. 'dmri_reg')
        # and is safe when several workers create the directory at once.
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        os.makedirs(dsn_dir, exist_ok=True)

        namer_dir = f"{dir_path}/tractography"
        os.makedirs(namer_dir, exist_ok=True)

        # The loaded atlas images are not used below; the calls are kept,
        # unbound, presumably so a missing/unreadable file fails before the
        # expensive registration -- TODO confirm they can be dropped.
        nib.load(labels_im_file)
        nib.load(atlas_t1w)

        fa_img = nib.load(fa_path)
        t1w_brain_img = nib.load(t1w_brain)
        brain_mask = np.asarray(t1w_brain_img.dataobj).astype("bool")

        streams_t1w = _dsn_output_path(
            namer_dir, "streamlines_t1w", ".trk", subnet, roi, conn_model,
            node_radius, curv_thr_list, step_list, track_type, traversal,
            min_length)

        density_t1w = _dsn_output_path(
            namer_dir, "density_map_t1w", ".nii.gz", subnet, roi, conn_model,
            node_radius, curv_thr_list, step_list, track_type, traversal,
            min_length)

        # NOTE(review): fixed, shared location for the QC screenshot;
        # concurrent runs will overwrite each other's image.
        streams_warp_png = '/tmp/dsn.png'

        # SyN registration between the T1w brain and the anisotropic power
        # map
        [mapping, affine_map,
         warped_fa] = regutils.wm_syn(t1w_brain, ap_path, dsn_dir)

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        streams_in_curr_grid = transform_streamlines(streamlines,
                                                     affine_map.affine_inv)

        streams_final_filt = regutils.warp_streamlines(t1w_brain_img.affine,
                                                       fa_img.affine, mapping,
                                                       warped_fa_img,
                                                       streams_in_curr_grid,
                                                       brain_mask)

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=t1w_brain_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_t1w, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        show_template_bundles(streams_final_filt_final, atlas_t1w,
                              streams_warp_png)

        # Streamline density map on the grid of the warped image
        nib.save(
            nib.Nifti1Image(
                density_map(streams_final_filt_final,
                            affine=np.eye(4),
                            vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_t1w,
        )

        # Release the large streamline containers before returning.
        del (
            tractogram,
            streamlines,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        assert len(coords) == len(labels)

        atlas_for_streams = atlas_t1w

    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_t1w = streams
        warped_fa = fa_path
        atlas_for_streams = labels_im_file

    return (streams_t1w, dir_path, track_type, conn_model, subnet, node_radius,
            dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune, atlas,
            parcellation, labels, coords, norm, binary, atlas_for_streams,
            traversal, warped_fa, min_length)
示例#8
0
文件: conftest.py 项目: dPys/PyNets
def tractography_estimation_data(dmri_estimation_data):
    """
    Generate a small tractogram and tracking mask from the dMRI test data.

    A CSA ODF model is fit to the DWI data inside the intersection of the
    white-matter and B0 masks, local tracking is seeded from that mask,
    and the resulting streamlines are saved as a .trk file in the system
    temp directory.

    Parameters
    ----------
    dmri_estimation_data : dict
        Mapping that provides the keys 'gtab', 'f_pve_wm', 'dwi_file' and
        'B0_mask' read below.

    Yields
    ------
    dict
        {'trk': path to the saved tractogram,
         'mask': path to the saved tracking mask}
    """
    # The temp file is created only to obtain a unique path in the temp
    # directory. Its handle is closed right away so no descriptor is leaked
    # and the file is not held open while it is renamed below.
    with tempfile.NamedTemporaryFile(mode='w+',
                                     suffix='.trk',
                                     delete=False) as path_tmp:
        trk_path_tmp = str(path_tmp.name)
    dir_path = os.path.dirname(trk_path_tmp)

    gtab = dmri_estimation_data['gtab']
    wm_img = nib.load(dmri_estimation_data['f_pve_wm'])
    dwi_img = nib.load(dmri_estimation_data['dwi_file'])
    dwi_data = dwi_img.get_fdata()
    B0_mask_img = nib.load(dmri_estimation_data['B0_mask'])
    # Binarize both masks and intersect them to obtain the tracking mask.
    mask_img = intersect_masks(
        [
            nib.Nifti1Image(np.asarray(
                wm_img.dataobj).astype('bool').astype('int'),
                            affine=wm_img.affine),
            nib.Nifti1Image(np.asarray(
                B0_mask_img.dataobj).astype('bool').astype('int'),
                            affine=B0_mask_img.affine)
        ],
        threshold=1,
        connected=False,
    )

    mask_data = mask_img.get_fdata()
    mask_file = fname_presuffix(dmri_estimation_data['B0_mask'],
                                suffix="tracking_mask",
                                use_ext=True)
    mask_img.to_filename(mask_file)
    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model,
                                 dwi_data,
                                 default_sphere,
                                 relative_peak_threshold=.8,
                                 min_separation_angle=45,
                                 mask=mask_data)

    # Tracking stops as soon as a streamline leaves the mask.
    stopping_criterion = BinaryStoppingCriterion(mask_data)

    seed_mask = (mask_data == 1)
    seeds = seeds_from_mask(seed_mask, dwi_img.affine, density=[1, 1, 1])

    streamlines_generator = LocalTracking(csa_peaks,
                                          stopping_criterion,
                                          seeds,
                                          affine=dwi_img.affine,
                                          step_size=.5)
    streamlines = Streamlines(streamlines_generator)
    sft = StatefulTractogram(streamlines,
                             B0_mask_img,
                             origin=Origin.NIFTI,
                             space=Space.VOXMM)
    sft.remove_invalid_streamlines()
    # Move the placeholder temp file to its final name; save_tractogram
    # then writes the tractogram to that path.
    trk = f"{dir_path}/tractogram.trk"
    os.rename(trk_path_tmp, trk)
    save_tractogram(sft, trk, bbox_valid_check=False)
    # Free the large intermediates before handing control to the consumer.
    del streamlines, sft, streamlines_generator, seeds, seed_mask, csa_peaks, \
        csa_model, dwi_data, mask_data
    dwi_img.uncache()
    mask_img.uncache()
    gc.collect()

    yield {'trk': trk, 'mask': mask_file}
            track_moving_warped = np.zeros([n_, N_points, 3])
            for idx in range(n_):
                track_moving_warped[idx] = moving_warped[idx *
                                                         N_points:N_points *
                                                         (idx + 1)]
        else:
            track_moving_warped = track_moving.copy()
            for i, streamline in enumerate(track_moving):
                streamline_warp = warpNeigh.predict(streamline)
                track_moving_warped[i] += streamline_warp

        warped_filename = WarpedShot_dir + '/track_warped' + suffix + '.trk'

        track_moving_warped_sft = StatefulTractogram(track_moving_warped,
                                                     FA_nib, Space.RASMM)
        idx_toremove, idx_tokeep = track_moving_warped_sft.remove_invalid_streamlines(
        )
        save_tractogram(track_moving_warped_sft, warped_filename)
        print("save warped tracts as: " + warped_filename)

        #show_both_bundles((track_moving_warped,track_fixed,track_moving),colors=[window.colors.cyan,window.colors.green,window.colors.red],fname=ScreenShot_dir+'/after_Warp.png')
#%% PLOT
if plot_flag:
    import matplotlib.pyplot as plt
    #from mpl_toolkits.mplot3d import Axes3D

    fig = plt.figure()
    #ax = fig.gca(projection='3d')

    it_arr = range(100000)

    #ax = fig.gca()