Example #1
0
def reference_info_zero_affine():
    """Check whether get_reference_info accepts a header with an all-zero affine.

    Returns
    -------
    bool
        True if get_reference_info returned normally, False if it raised
        ValueError. Any other exception propagates.
    """
    zero_affine = np.zeros((4, 4))
    header = create_nifti_header(zero_affine, [10, 10, 10], [1, 1, 1])
    try:
        get_reference_info(header)
    except ValueError:
        return False
    return True
Example #2
0
def compute_distance_barycenters(ref_1, ref_2, ref_2_transfo):
    """
    Compare the barycenter (center of volume) of two reference objects.
    The provided transformation will move the reference #2 and
    return the distance before and after transformation.

    Parameters
    ----------
    ref_1: reference object
        Any type supported by the sft as reference (e.g .nii or .trk).
    ref_2: reference object
        Any type supported by the sft as reference (e.g .nii or .trk).
    ref_2_transfo: np.ndarray
        4x4 homogeneous transformation applied to the barycenter of ref_2.

    Returns
    -------
    distance: tuple (2,)
        Tuple containing the distance before and after the transformation.
    """
    aff_1, dim_1, _, _ = get_reference_info(ref_1)
    aff_2, dim_2, _, _ = get_reference_info(ref_2)

    # The barycenter is taken as the world position of the middle of the
    # voxel grid.
    barycenter_1 = voxel_to_world(dim_1 / 2.0, aff_1)
    barycenter_2 = voxel_to_world(dim_2 / 2.0, aff_2)
    distance_before = np.linalg.norm(barycenter_1 - barycenter_2)

    # Homogeneous coordinates (x, y, z, 1) so that the 4x4 matrix applies
    # both its linear part and its translation. (The previous code used an
    # undefined name `row` here; np.r_ is the intended concatenation.)
    homogeneous_coord = np.r_[barycenter_2[0:3], 1.0].astype(float)
    barycenter_2 = np.dot(ref_2_transfo, homogeneous_coord)[0:3]

    distance_after = np.linalg.norm(barycenter_1 - barycenter_2)

    return distance_before, distance_after
Example #3
0
def test_reference_info_identical():
    # The .trk and .nii versions of the 'gs' data must report identical
    # spatial attributes (affine, dimensions, voxel sizes, voxel order).
    trk_info = get_reference_info(filepath_dix['gs.trk'])
    nii_info = get_reference_info(filepath_dix['gs.nii'])
    trk_affine, trk_dims, trk_vox_sizes, trk_vox_order = trk_info
    nii_affine, nii_dims, nii_vox_sizes, nii_vox_order = nii_info

    assert_allclose(trk_affine, nii_affine)
    assert_array_equal(trk_dims, nii_dims)
    assert_allclose(trk_vox_sizes, nii_vox_sizes)
    assert trk_vox_order == nii_vox_order
Example #4
0
def main():
    """Fuse several bundles by voxel (and optionally streamline) majority vote.

    Writes ``<prefix>voxels.nii.gz`` (binary mask of voxels covered by at
    least ``ratio_voxels`` of the bundles) and, with ``--same_tractogram``,
    ``<prefix>streamlines.trk`` (streamlines present in at least
    ``ratio_streamlines`` of the bundles).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    reference_file = args.reference if args.reference else args.in_bundles[0]

    # Every bundle is put in voxel space with corner origin so that the
    # streamline comparisons and the voxel maps share one convention.
    sft_list = []
    fusion_streamlines = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)
        fusion_streamlines.append(tmp_sft.streamlines)

    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    nb_bundles = len(args.in_bundles)
    volume = np.zeros(dimensions)
    # One row per fused streamline, one column per bundle; a cell becomes
    # non-zero when the streamline is found in that bundle.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, sft in enumerate(sft_list):
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * nb_bundles)
        real_indices = np.where(
            np.sum(streamlines_vote, axis=1) >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * nb_bundles)] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Example #5
0
def get_seeds_from_wm(wm_path, threshold=0):
    """Save one single-point streamline per mask voxel above ``threshold``.

    The result is written next to ``wm_path`` as ``seeds_from_wm.trk``.

    Parameters
    ----------
    wm_path : str
        Path of the (white matter) image, loaded with nibabel.
        NOTE(review): assumes a 3D image, since the homogeneous coordinates
        built below must match the 4x4 affine.
    threshold : float, optional
        Voxels whose value is strictly greater than this become seeds.
    """
    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    # Voxel indices of the seeds as homogeneous coordinates (i, j, k, 1).
    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

    # Map to world coordinates with the image affine; each seed becomes a
    # one-point streamline of shape (1, 3).
    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    header = TrkFile.create_empty_header()
    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    # Points are already in world coordinates, hence the identity affine.
    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
Example #6
0
def transform_anatomy(transfo, reference, moving, filename_to_save,
                      interp='linear', keep_dtype=False):
    """
    Apply transformation to an image using Dipy's tool

    Parameters
    ----------
    transfo: numpy.ndarray
        Transformation matrix to be applied
    reference: str
        Filename of the reference image (target)
    moving: str
        Filename of the moving image
    filename_to_save: str
        Filename of the output image
    interp : string, either 'linear' or 'nearest'
        the type of interpolation to be used, either 'linear'
        (for k-linear interpolation) or 'nearest' for nearest neighbor
    keep_dtype : bool
        If True, keeps the data_type of the input moving image when saving
        the output image

    Raises
    ------
    ValueError
        If the moving image has fewer than 3 dimensions, stores TrackVis
        RGB (structured) voxels, or has another unsupported layout.
    """
    grid2world, dim, _, _ = get_reference_info(reference)

    nib_file = nib.load(moving)
    curr_type = nib_file.get_data_dtype()
    if keep_dtype:
        moving_data = np.asanyarray(nib_file.dataobj).astype(curr_type)
    else:
        moving_data = nib_file.get_fdata(dtype=np.float32)
    moving_affine = nib_file.affine

    # Both branches below index moving_data[0, 0, 0]; report unsupported
    # low-dimensional data with the usual ValueError, not an IndexError.
    if moving_data.ndim < 3:
        raise ValueError('Does not support this dataset (shape, type, etc)')

    if moving_data.ndim == 3 and isinstance(moving_data[0, 0, 0],
                                            np.ScalarType):
        # Scalar 3D volume: resample directly with the inverse transform.
        orig_type = moving_data.dtype
        affine_map = AffineMap(np.linalg.inv(transfo),
                               dim, grid2world,
                               moving_data.shape, moving_affine)
        resampled = affine_map.transform(moving_data.astype(np.float64),
                                         interpolation=interp)
        nib.save(nib.Nifti1Image(resampled.astype(orig_type), grid2world),
                 filename_to_save)
    elif len(moving_data[0, 0, 0]) > 1:
        if isinstance(moving_data[0, 0, 0], np.void):
            raise ValueError('Does not support TrackVis RGB')

        # Multi-valued voxels: the map is built on the spatial axes only
        # and transform_dwi (defined elsewhere) handles the remaining axis.
        # The reference data is only needed in this branch, so it is
        # loaded here rather than for every call.
        static_data = nib.load(reference).get_fdata(dtype=np.float32)
        affine_map = AffineMap(np.linalg.inv(transfo),
                               dim[0:3], grid2world,
                               moving_data.shape[0:3], moving_affine)

        orig_type = moving_data.dtype
        resampled = transform_dwi(affine_map, static_data, moving_data,
                                  interpolation=interp)
        nib.save(nib.Nifti1Image(resampled.astype(orig_type), grid2world),
                 filename_to_save)
    else:
        raise ValueError('Does not support this dataset (shape, type, etc)')
Example #7
0
def assert_same_resolution(images):
    """
    Check that multiple images share the same shape and affine.

    Parameters
    ----------
    images : array of string or string
        List of images or an image.

    Raises
    ------
    Exception
        If no image is given, or if the shape or the affine of any image
        differs from those of the first image.
    """
    if isinstance(images, str):
        images = [images]

    if len(images) == 0:
        raise Exception("Can't check if images are of the same "
                        "resolution/affine. No image has been given")

    aff_1, shape_1, _, _ = get_reference_info(images[0])
    for i in images[1:]:
        aff_2, shape_2, _, _ = get_reference_info(i)
        # Both the grid shape and the affine must match. Whole-array
        # comparisons avoid the ambiguous truth value of element-wise ==
        # and the precedence bug of `not a and b`.
        if not (np.array_equal(shape_1, shape_2)
                and np.allclose(aff_1, aff_2)):
            raise Exception("Images are not of the same resolution/affine")
    def __init__(self, nb_vertices=None, nb_streamlines=None, init_as=None,
                 reference=None):
        """ Initialize an empty TrxFile, support preallocation

        Parameters
        ----------
        nb_vertices : int, optional
            Number of vertices to preallocate; must be given together with
            nb_streamlines.
        nb_streamlines : int, optional
            Number of streamlines to preallocate; must be given together
            with nb_vertices.
        init_as : optional
            Object whose header provides VOXEL_TO_RASMM and DIMENSIONS
            (presumably another TrxFile); only allowed when preallocating.
        reference : optional
            Any object accepted by get_reference_info, used for the spatial
            attributes when init_as is not given.

        Raises
        ------
        ValueError
            If only one of nb_vertices / nb_streamlines is given, or if
            init_as is given without both of them.
        """
        if init_as is not None:
            affine = init_as.header['VOXEL_TO_RASMM']
            dimensions = init_as.header['DIMENSIONS']
        elif reference is not None:
            affine, dimensions, _, _ = get_reference_info(reference)
        else:
            logging.debug('No reference provided, using blank space '
                          'attributes, please update them later.')
            affine = np.eye(4).astype(np.float32)
            dimensions = np.array([1, 1, 1], dtype=np.uint16)

        if nb_vertices is None and nb_streamlines is None:
            if init_as is not None:
                raise ValueError('Cant use init_as without declaring '
                                 'nb_vertices AND nb_streamlines')
            logging.debug('Intializing empty TrxFile.')
            self.header = {}
            # Using the new format default type
            tmp_strs = ArraySequence()
            tmp_strs._data = tmp_strs._data.astype(np.float16)
            tmp_strs._offsets = tmp_strs._offsets.astype(np.uint64)
            tmp_strs._lengths = tmp_strs._lengths.astype(np.uint32)
            self.streamlines = tmp_strs
            self.groups = {}
            self.data_per_streamline = {}
            self.data_per_vertex = {}
            self.data_per_group = {}
            self._uncompressed_folder_handle = None

            nb_vertices = 0
            nb_streamlines = 0

        elif nb_vertices is not None and nb_streamlines is not None:
            logging.debug('Preallocating TrxFile with size {} streamlines '
                          'and {} vertices.'.format(nb_streamlines, nb_vertices))
            trx = self._initialize_empty_trx(nb_streamlines, nb_vertices,
                                             init_as=init_as)
            # Adopt all attributes of the freshly preallocated object.
            self.__dict__ = trx.__dict__
        else:
            raise ValueError('You must declare both nb_vertices AND '
                             'nb_streamlines')

        self.header['VOXEL_TO_RASMM'] = affine
        self.header['DIMENSIONS'] = dimensions
        self.header['NB_VERTICES'] = nb_vertices
        self.header['NB_STREAMLINES'] = nb_streamlines
        self._copy_safe = True
Example #9
0
def to_bin_mask(tractogram, reference=None):
    """Mark every voxel containing a streamline point in a binary mask.

    Parameters
    ----------
    tractogram : StatefulTractogram
        Tractogram whose points are rasterized. It is converted to voxel
        space in place (``to_vox``).
    reference : optional
        Any object accepted by ``get_reference_info``. Defaults to the
        module-level ``reference_anatomy``, as before.

    Returns
    -------
    mask : np.ndarray
        float64 array with the reference dimensions, 1 where a point falls.
    affine : np.ndarray
        Affine of the reference.
    """
    if reference is None:
        reference = reference_anatomy
    affine, dimensions, _, _ = get_reference_info(reference)

    # np.float was an alias of the builtin float (float64) and was removed
    # in NumPy 1.24.
    mask = np.zeros(dimensions, dtype=np.float64)
    tractogram.to_vox()
    try:
        # Voxel coordinates are truncated to integer indices.
        streams = np.vstack(tractogram.streamlines).astype(np.int32)
        mask[streams[:, 0], streams[:, 1], streams[:, 2]] = 1
    except (ValueError, IndexError):
        # ValueError: nothing to stack (no streamlines);
        # IndexError: a point lies outside the volume.
        print('Trivial mask!')

    return mask, affine
Example #10
0
def density_map(tractogram, n_sls=None, to_vox=False, normalize=False):
    """
    Create a streamline density map.
    based on:
    https://dipy.org/documentation/1.1.1./examples_built/streamline_formats/

    Parameters
    ----------
    tractogram : StatefulTractogram
        Stateful tractogram whose streamlines are used to make
        the density map. With ``to_vox=True`` it is converted in place.
    n_sls : int or None, optional
        n_sls to randomly select to make the density map.
        If None, all streamlines are used.
        Default: None
    to_vox : bool, optional
        Whether to put the stateful tractogram in VOX space before making
        the density map.
        Default: False
    normalize : bool, optional
        Whether to normalize maximum values to 1. An empty map stays all
        zeros instead of becoming NaN.
        Default: False

    Returns
    -------
    Nifti1Image containing the density map.
    """
    if to_vox:
        tractogram.to_vox()

    sls = tractogram.streamlines
    if n_sls is not None:
        sls = select_random_set_of_streamlines(sls, n_sls)

    affine, vol_dims, voxel_sizes, _ = get_reference_info(tractogram)
    # The identity affine makes dtu.density_map treat the points as voxel
    # coordinates. NOTE(review): the result is only meaningful when the
    # tractogram is already in VOX space (e.g. to_vox=True) — confirm.
    tractogram_density = dtu.density_map(sls, np.eye(4), vol_dims)
    if normalize:
        tractogram_density = tractogram_density.astype(np.float64)
        max_density = tractogram_density.max()
        # Guard against 0 / 0 when no streamline falls inside the volume.
        if max_density > 0:
            tractogram_density /= max_density

    nifti_header = create_nifti_header(affine, vol_dims, voxel_sizes)
    density_map_img = nib.Nifti1Image(tractogram_density, affine, nifti_header)

    return density_map_img
def load_node_nifti(directory, in_label, out_label, ref_filename):
    """Load the NIfTI data associated with a pair of labels, if it exists.

    ``{in_label}_{out_label}.nii.gz`` is looked up first, then
    ``{out_label}_{in_label}.nii.gz``, both inside ``directory``.

    Returns
    -------
    np.ndarray
        Data of the first file found, or zeros shaped like the dimensions
        of ``ref_filename`` when neither file exists.

    Raises
    ------
    IOError
        If the file found is not header-compatible with ``ref_filename``.
    """
    candidates = [
        os.path.join(directory, '{}_{}.nii.gz'.format(in_label, out_label)),
        os.path.join(directory, '{}_{}.nii.gz'.format(out_label, in_label)),
    ]
    in_filename = next((name for name in candidates if os.path.isfile(name)),
                       None)

    if in_filename is not None:
        if not is_header_compatible(in_filename, ref_filename):
            message = '{} and {} do not have a compatible header'.format(
                in_filename, ref_filename)
            logging.error(message)
            # Carry the message in the exception too, not only in the log.
            raise IOError(message)
        return nib.load(in_filename).get_fdata()

    _, dims, _, _ = get_reference_info(ref_filename)
    return np.zeros(dims)
    def from_tractogram(tractogram, reference, cast_position=np.float16):
        """ Generate a valid TrxFile from a Nibabel Tractogram

        Parameters
        ----------
        tractogram : nibabel Tractogram
            Source streamlines with their data_per_streamline and
            data_per_point. Its streamlines object is not modified.
        reference : object
            Any object accepted by get_reference_info; provides the
            VOXEL_TO_RASMM and DIMENSIONS header fields.
        cast_position : numpy dtype, optional
            dtype of the stored positions (default np.float16).

        Returns
        -------
        TrxFile
            Object reloaded from a temporary directory; the directory handle
            is kept in ``_uncompressed_folder_handle``.
        """
        from copy import copy as shallow_copy

        if not np.issubdtype(cast_position, np.floating):
            logging.warning('Casting as {}, considering using a floating point '
                            'dtype.'.format(cast_position))

        trx = TrxFile(nb_vertices=len(tractogram.streamlines._data),
                      nb_streamlines=len(tractogram.streamlines))

        affine, dimensions, _, _ = get_reference_info(reference)
        trx.header = {'DIMENSIONS': dimensions,
                      'VOXEL_TO_RASMM': affine,
                      'NB_VERTICES': len(tractogram.streamlines._data),
                      'NB_STREAMLINES': len(tractogram.streamlines)}

        # Shallow copy: the new object initially shares the input arrays,
        # but the assignments below only rebind the copy's attributes, so
        # the caller's streamlines keep their original offsets/data and no
        # full copy of the positions is made before casting.
        tmp_streamlines = shallow_copy(tractogram.streamlines)

        # Cast the int64 of Nibabel to uint64
        tmp_streamlines._offsets = tmp_streamlines._offsets.astype(np.uint64)
        # Cast whenever the stored dtype differs from the requested one
        # (including float64 -> float32); copy=False avoids a copy when the
        # dtype already matches.
        tmp_streamlines._data = tmp_streamlines._data.astype(cast_position,
                                                             copy=False)

        trx.streamlines = tmp_streamlines
        trx.data_per_streamline = tractogram.data_per_streamline
        trx.data_per_vertex = tractogram.data_per_point

        # For safety and for RAM, convert the whole object to memmaps
        tmpdir = tempfile.TemporaryDirectory()
        save(trx, tmpdir.name)
        trx = load_from_directory(tmpdir.name)
        trx._uncompressed_folder_handle = tmpdir

        return trx
Example #13
0
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):
    """Build and save a seed tractogram for the ISMRM data.

    Parameters
    ----------
    data_dir : str
        Root directory containing ``bundles/`` and ``masks/``.
    source : {"wm", "trk", "brain"}
        "trk": seeds are both endpoints of every fiber of the bundles in
        ``bundles/`` (converted to .trk first); "wm" / "brain": seeds are
        the voxels of the resized mask above ``threshold``.
    keep : float
        Fraction of seeds to keep; values below 1 trigger a random
        subsampling (fixed seed 42).
    weighted : bool
        For "trk", give every bundle the same total sampling probability.
        Forced to False for mask sources.
    threshold : float
        Mask threshold (mask sources only).
    voxel : int
        Target voxel size, passed to ``mrresize`` as ``voxel / 100``
        (presumably hundredths of mm) and used in the output file name.

    Raises
    ------
    ValueError
        If ``source`` is unknown or no .trk bundle is found.
    """
    if source not in ("wm", "trk", "brain"):
        raise ValueError('source must be one of "wm", "trk" or "brain", '
                         'got {!r}.'.format(source))

    trk_dir = os.path.join(data_dir, "bundles")

    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks",
                                    "wm_{}.nii.gz".format(voxel))
    else:
        # NOTE(review): these paths are relative to the working directory
        # (not data_dir) and the resized name is fixed to "125" whatever
        # `voxel` is — confirm this is intended.
        anat_path = os.path.join("subjects", "ismrm_gt",
                                 "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt",
                                    "dwi_brain_mask_125.nii.gz")

    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path,
        resized_path
    ])

    if source == "trk":

        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py", "-i",
            trk_dir, "-a", resized_path, "-vtk2trk"
        ])

        print("Loading seed bundles...")
        seed_bundles = []
        for i, trk_path in enumerate(glob.glob(os.path.join(trk_dir,
                                                            "*.trk"))):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            if i == 0:
                # The header of the first bundle is reused for the output.
                header = trk_file.header

        if not seed_bundles:
            raise ValueError('No .trk bundle found in {}.'.format(trk_dir))

        n_seeds = sum(len(bundle) for bundle in seed_bundles)
        n_bundles = len(seed_bundles)

        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))

        # Shape (n_seeds, 1, 3): every seed is a one-point streamline.
        seeds = np.array([[seed] for bundle in seed_bundles
                          for seed in bundle])

        if keep < 1:
            if weighted:
                # Each bundle gets a total probability of 1 / n_bundles,
                # split equally between its seeds.
                p = np.zeros(n_seeds)
                offset = 0
                for bundle in seed_bundles:
                    bundle_size = len(bundle)
                    p[offset:offset + bundle_size] = (
                        1 / (bundle_size * n_bundles))
                    offset += bundle_size
            else:
                p = np.ones(n_seeds) / n_seeds

    else:

        weighted = False

        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()

        # Voxel indices above threshold, mapped to world coordinates.
        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

        n_seeds = len(seeds)

        if keep < 1:
            p = np.ones(n_seeds) / n_seeds

        header = TrkFile.create_empty_header()

        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        # Local generator: same draws as np.random.seed(42) followed by
        # np.random.choice, without resetting the global RNG state.
        rng = np.random.RandomState(42)
        keep_idx = rng.choice(len(seeds),
                              size=keep_n,
                              replace=False,
                              p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_dir = os.path.join(data_dir, "seeds")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, "seeds_from_{}_{}_vox{:03d}.trk")
    save_path = save_path.format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    os.remove(resized_path)
    # The .trk bundles are only produced (by the converter) for the "trk"
    # source; never delete pre-existing files for the mask sources.
    if source == "trk":
        for trk_path in glob.glob(os.path.join(trk_dir, "*.trk")):
            os.remove(trk_path)
Example #14
0
def main():
    """Fuse several bundles by voxel (and optionally streamline) majority vote.

    Writes ``<prefix>voxels.nii.gz`` (voxels covered by at least
    ``ratio_voxels`` of the bundles) and, with ``--same_tractogram``,
    ``<prefix>streamlines.trk`` (streamlines found in at least
    ``ratio_streamlines`` of the bundles).
    """
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    fusion_streamlines = []
    for name in args.in_bundles:
        fusion_streamlines.extend(
            load_tractogram_with_reference(parser, args, name).streamlines)

    fusion_streamlines, _ = perform_streamlines_operation(
        union, [fusion_streamlines], 0)
    fusion_streamlines = ArraySequence(fusion_streamlines)
    if args.reference:
        reference_file = args.reference
    else:
        reference_file = args.in_bundles[0]

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    nb_bundles = len(args.in_bundles)
    volume = np.zeros(dimensions)
    # One row per fused streamline, one column per bundle.
    streamlines_vote = dok_matrix((len(fusion_streamlines), nb_bundles))

    for i, name in enumerate(args.in_bundles):
        if not is_header_compatible(reference_file, name):
            raise ValueError('Both headers are not the same')
        sft = load_tractogram_with_reference(parser, args, name)
        bundle = sft.get_streamlines_copy()
        sft.to_vox()
        bundle_vox_space = sft.get_streamlines_copy()
        binary = compute_tract_counts_map(bundle_vox_space, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = perform_streamlines_operation(
                intersection, [fusion_streamlines, bundle], 0)
            streamlines_vote[list(indices), i] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * nb_bundles)
        # Total votes per fused streamline, computed once for all rows.
        votes = np.asarray(streamlines_vote.sum(axis=1)).ravel()
        real_indices = np.where(votes >= ratio_value)[0]

        new_streamlines = fusion_streamlines[real_indices]

        sft = StatefulTractogram(new_streamlines, reference_file, Space.RASMM)
        save_tractogram(sft, output_streamlines_filename)

    # The voxel map is thresholded with the voxel ratio (it previously used
    # ratio_streamlines by mistake).
    volume[volume < int(args.ratio_voxels * nb_bundles)] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Example #15
0
"""
The function ``load_tractogram`` requires a reference; any of the following
inputs is considered valid (as long as they are in the same shared space)
- Nifti filename
- Trk filename
- nib.nifti1.Nifti1Image
- nib.streamlines.trk.TrkFile
- nib.nifti1.Nifti1Header
- Trk header (dict)
- Stateful Tractogram

The reason why this parameter is required is to guarantee all information
related to space attributes is always present.
"""

affine, dimensions, voxel_sizes, voxel_order = get_reference_info(
    reference_anatomy)
# Display each of the four spatial attributes on its own line.
for space_attribute in (affine, dimensions, voxel_sizes, voxel_order):
    print(space_attribute)
"""
If you have a Trk file that was generated using a particular anatomy,
to be considered valid all fields must correspond between the headers.
It can be easily verified using this function, which also accepts
the same variety of input as ``get_reference_info``
"""

# Print whether the first bundle's header matches the reference anatomy.
headers_match = is_header_compatible(reference_anatomy, bundles_filename[0])
print(headers_match)
"""
If a TRK was generated with a valid header but the reference NIFTI was lost,
a header can be generated and then used to create a fake NIFTI file.
Example #16
0
    def __init__(self, streamlines, reference, space,
                 origin=Origin.NIFTI,
                 data_per_point=None, data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            After tracking the space is VOX, after loading with nibabel
            the space is RASMM
        origin : Enum (dipy.io.stateful_tractogram.Origin), optional
            Current origin in which the streamlines are (center or corner)
            After loading with nibabel the origin is CENTER
        data_per_point : dict, optional
            Dictionary in which each key has X items, each item has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict, optional
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If the reference (or the provided space attributes) is invalid.
        ValueError
            If space is not a Space member or origin is not an Origin member.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        if isinstance(reference, type(self)):
            logger.warning('Using a StatefulTractogram as reference, this '
                           'will copy only the space_attributes, not '
                           'the state. The variables space and origin '
                           'must be specified separately.')
            logger.warning('To copy the state from another StatefulTractogram '
                           'you may want to use the function from_sft '
                           '(static function of the StatefulTractogram).')

        # A 4-tuple is taken as ready-made (affine, dimensions, voxel_sizes,
        # voxel_order) space attributes.
        if isinstance(reference, tuple) and len(reference) == 4:
            if is_reference_info_valid(*reference):
                space_attributes = reference
            else:
                raise TypeError('The provided space attributes are not '
                                'considered valid, please correct before '
                                'using them with StatefulTractogram.')
        else:
            space_attributes = get_reference_info(reference)
            if space_attributes is None:
                raise TypeError('Reference MUST be one of the following:\n'
                                'Nifti or Trk filename, Nifti1Image or '
                                'TrkFile, Nifti1Header or trk.header (dict).')

        (self._affine, self._dimensions,
         self._voxel_sizes, self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance instead of `in`: enum membership tests raise TypeError
        # for non-members on Python 3.8-3.11 and accept raw values on 3.12+.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX.')
        self._space = space

        if not isinstance(origin, Origin):
            raise ValueError('Origin MUST be from Origin enum, '
                             'e.g Origin.NIFTI.')
        self._origin = origin
        logger.debug(self)
Example #17
0
def main():
    """Run GPU ODF tracking and stream the resulting tractogram to disk."""
    t_init = perf_counter()
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_odf, args.in_mask, args.in_seed])
    assert_outputs_exist(parser, args, args.out_tractogram)
    if args.compress is not None:
        verify_compression_th(args.compress)

    odf_sh_img = nib.load(args.in_odf)
    # Tracking happens in voxel space with a single scalar voxel size
    # (zooms[0] drives the step size, the length limits and the conversion
    # back to mm), so anisotropic data would give wrong streamlines.
    zooms = odf_sh_img.header.get_zooms()[:3]
    if not np.allclose(np.mean(zooms), zooms[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be ran robustly.')

    mask = get_data_as_mask(nib.load(args.in_mask))
    seed_mask = get_data_as_mask(nib.load(args.in_seed))
    odf_sh = odf_sh_img.get_fdata(dtype=np.float32)

    t0 = perf_counter()
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Seeds are returned with origin `center`.
    # However, GPUTracker expects origin to be `corner`.
    # Therefore, we need to shift the seed positions by half voxel.
    seeds = random_seeds_from_mask(seed_mask,
                                   np.eye(4),
                                   seeds_count=nb_seeds,
                                   seed_count_per_voxel=seed_per_vox,
                                   random_seed=args.rng_seed) + 0.5
    logging.info('Generated {0} seed positions in {1:.2f}s.'.format(
        len(seeds),
        perf_counter() - t0))

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    vox_max_length = args.max_length / voxel_size
    vox_min_length = args.min_length / voxel_size
    # Lengths are expressed as numbers of points along the streamline.
    min_strl_len = int(vox_min_length / vox_step_size) + 1
    max_strl_len = int(vox_max_length / vox_step_size) + 1

    # initialize tracking
    tracker = GPUTacker(odf_sh, mask, seeds, vox_step_size, min_strl_len,
                        max_strl_len, args.theta, args.sh_basis,
                        args.batch_size, args.forward_only, args.rng_seed)

    # wrapper for tracker.track() yielding one TractogramItem per
    # streamline for use with the LazyTractogram.
    def tracks_generator_wrapper():
        for strl, seed in tracker.track():
            # seed must be saved in voxel space, with origin `center`.
            dps = {'seeds': seed - 0.5} if args.save_seeds else {}

            # TODO: Investigate why the streamline must NOT be shifted to
            # origin `corner` for LazyTractogram.
            strl *= voxel_size  # in mm.
            if args.compress:
                strl = compress_streamlines(strl, args.compress)
            yield TractogramItem(strl, dps, {})

    # instantiate tractogram
    tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper)
    tractogram.affine_to_rasmm = odf_sh_img.affine

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(odf_sh_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
    logging.info('Saved tractogram to {0}.'.format(args.out_tractogram))

    # Total runtime
    logging.info('Total runtime of {0:.2f}s.'.format(perf_counter() - t_init))
Example #18
0
    def __init__(self,
                 streamlines,
                 reference,
                 space,
                 shifted_origin=False,
                 data_per_point=None,
                 data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            Typically after tracking the space is VOX, after nibabel loading
            the space is RASMM
        shifted_origin : bool
            Information on the position of the origin:
            False (default) means the origin is at the center of the voxel,
            True means it is at the corner of the voxel.
            NOTE(review): center is usually called the NIfTI convention and
            corner the TrackVis one — the previous wording had them swapped;
            confirm.
        data_per_point : dict
            Dictionary in which each key has X items, each item has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If the reference is not supported or shifted_origin is not a bool.
        ValueError
            If space is not a member of the Space enum.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        space_attributes = get_reference_info(reference)
        if space_attributes is None:
            raise TypeError('Reference MUST be one of the following:\n' +
                            'Nifti or Trk filename, Nifti1Image or TrkFile, ' +
                            'Nifti1Header or trk.header (dict)')

        (self._affine, self._dimensions, self._voxel_sizes,
         self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance instead of `in`: enum membership tests raise TypeError
        # for non-members on Python 3.8-3.11 and accept raw values on 3.12+.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX')
        self._space = space

        if not isinstance(shifted_origin, bool):
            raise TypeError('shifted_origin MUST be a boolean')
        self._shifted_origin = shifted_origin
        logging.debug(self)
def main():
    """Track streamlines from an ODF SH image and save them to disk.

    Parses and validates the command line, checks that the ODF SH image
    has isotropic voxels, draws seeds from the seeding mask, runs
    LocalTracking in voxel space (identity affine) inside the tracking
    mask, keeps streamlines whose length lies between args.min_length and
    args.max_length (converted to voxel units), optionally compresses them
    and streams them to ``args.out_tractogram`` (trk or tck).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.in_odf, args.in_seed, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    mask_img = nib.load(args.in_mask)
    mask_data = get_data_as_mask(mask_img, dtype=bool)

    # Make sure the data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    odf_sh_img = nib.load(args.in_odf)
    if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
                       odf_sh_img.header.get_zooms()[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be ran robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Voxels are isotropic (checked above), so one zoom converts mm to voxels
    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(dtype=np.float32),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines_generator = LocalTracking(
        _get_direction_getter(args),
        BinaryStoppingCriterion(mask_data),
        seeds, np.eye(4),
        step_size=vox_step_size, max_cross=1,
        maxlen=max_steps,
        fixedstep=True, return_all=True,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    # Streamlines are in voxel space, so the length bounds are too
    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        # The generator yields (streamline, seed) pairs. Accumulate the kept
        # pairs explicitly: unpacking zip(*...) into two names raises
        # ValueError when no streamline survives the length filter.
        filtered_streamlines = []
        kept_seeds = []
        for streamline, seed in streamlines_generator:
            if scaled_min_length <= length(streamline) <= scaled_max_length:
                filtered_streamlines.append(streamline)
                kept_seeds.append(seed)
        data_per_streamlines = {'seeds': lambda: kept_seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines_generator
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
def _processing_wrapper(args):
    """Compute the requested connectivity measures for one pair of labels.

    Parameters
    ----------
    args: sequence
        Arguments packed into a single sequence (presumably so the function
        can be mapped over a process pool):
        [0] hdf5 filename holding the decomposed streamlines,
        [1] labels image (reference space),
        [2] (in_label, out_label) pair,
        [3] measures to compute: names ('volume', 'streamline_count',
            'length', 'similarity'), map directories, or
            (metric_filename, metric_img) tuples,
        [4] None, or a sequence whose first item is the similarity
            directory,
        [5] weighted flag for metric averages,
        [6] output prefix for data_per_streamline averages (falsy to skip).

    Returns
    -------
    dict or None
        {(in_label, out_label): {measure: value}}, or None when the pair
        is absent from the hdf5 file or holds no streamline.

    Raises
    ------
    ValueError
        If the hdf5 affine/dimensions do not match the labels image.
    IOError
        If a metric image header is incompatible with the labels image.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    # Always bound, so the 'similarity' test below cannot raise
    # UnboundLocalError when no similarity directory was provided.
    similarity_directory = None
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]

    # The context manager closes the file on every exit path (early
    # returns and exceptions included); all hdf5 reads happen inside it.
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        key = '{}_{}'.format(in_label, out_label)
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        # No streamline: averages would be NaN (and the weighted
        # normalization would divide by zero), so report nothing.
        if len(streamlines) == 0:
            return None

        affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
        measures_to_return = {}

        if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dimensions)):
            raise ValueError('Provided hdf5 have incompatible headers.')

        # Precompute to save one transformation, insert later
        if 'length' in measures_to_compute:
            streamlines_copy = list(streamlines)
            # scil_decompose_connectivity.py requires isotropic voxels
            mean_length = np.average(length(streamlines_copy)) * voxel_sizes[0]

        # If density is not required, do not compute it
        # Only required for volume, similarity and any metrics
        if not ((len(measures_to_compute) == 1 and
                 ('length' in measures_to_compute
                  or 'streamline_count' in measures_to_compute)) or
                (len(measures_to_compute) == 2 and
                 ('length' in measures_to_compute
                  and 'streamline_count' in measures_to_compute))):

            density = compute_tract_counts_map(streamlines, dimensions)

        if 'volume' in measures_to_compute:
            measures_to_return['volume'] = np.count_nonzero(density) * \
                np.prod(voxel_sizes)
            measures_to_compute.remove('volume')
        if 'streamline_count' in measures_to_compute:
            measures_to_return['streamline_count'] = len(streamlines)
            measures_to_compute.remove('streamline_count')
        if 'length' in measures_to_compute:
            measures_to_return['length'] = mean_length
            measures_to_compute.remove('length')
        if 'similarity' in measures_to_compute and similarity_directory:
            density_sim = load_node_nifti(similarity_directory, in_label,
                                          out_label, labels_img)
            if density_sim is None:
                ba_vox = 0
            else:
                ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

            measures_to_return['similarity'] = ba_vox
            measures_to_compute.remove('similarity')

        for measure in measures_to_compute:
            if isinstance(measure, str) and os.path.isdir(measure):
                map_dirname = measure
                map_data = load_node_nifti(map_dirname, in_label, out_label,
                                           labels_img)
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
            elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    msg = '{} do not have a compatible header'.format(
                        metric_filename)
                    logging.error(msg)
                    raise IOError(msg)

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    # Scale density to [0, 1]; idempotent on later passes
                    density = density / np.max(density)
                    voxels_value = metric_data * density
                    voxels_value = voxels_value[voxels_value > 0]
                else:
                    voxels_value = metric_data[density > 0]

                measures_to_return[metric_filename] = np.average(voxels_value)

        if include_dps:
            for dps_key in hdf5_file[key].keys():
                # Skip the arrays that encode the streamlines themselves
                if dps_key not in ['data', 'offsets', 'lengths']:
                    out_file = os.path.join(include_dps, dps_key)
                    measures_to_return[out_file] = np.average(
                        hdf5_file[key][dps_key])

        return {(in_label, out_label): measures_to_return}
def main():
    """Track streamlines from an SH image and save them to disk.

    Parses and validates the command line (length range, compression
    rate, seed counts), checks that the SH image has isotropic voxels,
    draws seeds from the seeding mask, runs LocalTracking in voxel space
    (identity affine) inside the mask, keeps streamlines whose length lies
    between args.min_length and args.max_length (converted to voxel
    units), optionally compresses them and streams them to
    ``args.output_file`` (trk or tck).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exist(parser, args, args.output_file)

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'.
                format(args.compress))

    # NOTE(review): 0 is falsy, so --npv 0 / --nt 0 pass these checks
    # silently and are then treated as "not provided" below; confirm the
    # parser defaults before tightening these to `is not None`.
    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    mask_img = nib.load(args.mask_file)
    mask_data = mask_img.get_fdata()

    # Make sure the SH data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Voxels are isotropic (checked above), so one zoom converts mm to voxels
    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryStoppingCriterion(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=vox_step_size,
                                max_cross=1,
                                maxlen=max_steps,
                                fixedstep=True,
                                return_all=True,
                                random_seed=args.seed,
                                save_seeds=args.save_seeds)

    # Streamlines are in voxel space, so the length bounds are too
    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        # The generator yields (streamline, seed) pairs. Accumulate the kept
        # pairs explicitly: unpacking zip(*...) into two names raises
        # ValueError when no streamline survives the length filter.
        filtered_streamlines = []
        kept_seeds = []
        for streamline, seed in streamlines:
            if scaled_min_length <= length(streamline) <= scaled_max_length:
                filtered_streamlines.append(streamline)
                kept_seeds.append(seed)
        data_per_streamlines = {'seeds': lambda: kept_seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Example #22
0
def _processing_wrapper(args):
    """Compute the requested connectivity measures for one pair of labels.

    Parameters
    ----------
    args: sequence
        Arguments packed into a single sequence (presumably so the function
        can be mapped over a process pool):
        [0] hdf5 filename holding the decomposed streamlines,
        [1] labels image (reference space),
        [2] (in_label, out_label) pair,
        [3] measures to compute: names ('volume', 'streamline_count',
            'length', 'similarity'), map directories,
            (metric_filename, metric_img) tuples, or
            ((lesion_filename, precomputed_lesion_labels), lesion_img)
            tuples,
        [4] None, or a sequence whose first item is the similarity
            directory,
        [5] if True, metric averages are weighted by the density map,
        [6] output prefix for data_per_streamline values (falsy to skip),
        [7] min_lesion_vol forwarded to compute_lesion_stats.

    Returns
    -------
    dict or None
        {(in_label, out_label): {measure: value}}, or None when the pair
        is absent from the hdf5 file or holds no streamline.

    Raises
    ------
    ValueError
        If the hdf5 affine/dimensions do not match the labels image.
    IOError
        If a metric or lesion image header is incompatible with the labels
        image.
    """
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    # Always bound, so the 'similarity' test below cannot raise
    # UnboundLocalError when no similarity directory was provided.
    similarity_directory = None
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    # The context manager closes the file on every exit path (early
    # returns and exceptions included); all hdf5 reads happen inside it.
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        key = '{}_{}'.format(in_label, out_label)
        if key not in hdf5_file:
            return None
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        if len(streamlines) == 0:
            return None

        affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
        measures_to_return = {}

        if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dimensions)):
            raise ValueError('Provided hdf5 have incompatible headers.')

        # Precompute to save one transformation, insert later
        if 'length' in measures_to_compute:
            streamlines_copy = list(streamlines)
            # scil_decompose_connectivity.py requires isotropic voxels
            mean_length = np.average(length(streamlines_copy))*voxel_sizes[0]

        # If density is not required, do not compute it
        # Only required for volume, similarity and any metrics
        if not ((len(measures_to_compute) == 1 and
                 ('length' in measures_to_compute or
                  'streamline_count' in measures_to_compute)) or
                (len(measures_to_compute) == 2 and
                 ('length' in measures_to_compute and
                  'streamline_count' in measures_to_compute))):

            density = compute_tract_counts_map(streamlines,
                                               dimensions)

        if 'volume' in measures_to_compute:
            measures_to_return['volume'] = np.count_nonzero(density) * \
                np.prod(voxel_sizes)
            measures_to_compute.remove('volume')
        if 'streamline_count' in measures_to_compute:
            measures_to_return['streamline_count'] = len(streamlines)
            measures_to_compute.remove('streamline_count')
        if 'length' in measures_to_compute:
            measures_to_return['length'] = mean_length
            measures_to_compute.remove('length')
        if 'similarity' in measures_to_compute and similarity_directory:
            density_sim = load_node_nifti(similarity_directory,
                                          in_label, out_label,
                                          labels_img)
            if density_sim is None:
                ba_vox = 0
            else:
                ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

            measures_to_return['similarity'] = ba_vox
            measures_to_compute.remove('similarity')

        for measure in measures_to_compute:
            # Maps
            if isinstance(measure, str) and os.path.isdir(measure):
                map_dirname = measure
                map_data = load_node_nifti(map_dirname,
                                           in_label, out_label,
                                           labels_img)
                measures_to_return[map_dirname] = np.average(
                    map_data[map_data > 0])
            elif isinstance(measure, tuple):
                # Metrics: (metric_filename, metric_img)
                if not isinstance(measure[0], tuple) \
                        and os.path.isfile(measure[0]):
                    metric_filename = measure[0]
                    metric_img = measure[1]
                    if not is_header_compatible(metric_img, labels_img):
                        msg = '{} do not have a compatible header'.format(
                            metric_filename)
                        logging.error(msg)
                        raise IOError(msg)

                    metric_data = metric_img.get_fdata(dtype=np.float64)
                    if weighted:
                        avg_value = np.average(metric_data, weights=density)
                    else:
                        avg_value = np.average(metric_data[density > 0])
                    measures_to_return[metric_filename] = avg_value
                # Lesions: ((lesion_filename, lesion_labels), lesion_img)
                else:
                    lesion_filename = measure[0][0]
                    computed_lesion_labels = measure[0][1]
                    lesion_img = measure[1]
                    if not is_header_compatible(lesion_img, labels_img):
                        msg = '{} do not have a compatible header'.format(
                            lesion_filename)
                        logging.error(msg)
                        raise IOError(msg)

                    # Separate name: do not shadow the labels voxel sizes
                    lesion_voxel_sizes = lesion_img.header.get_zooms()[0:3]
                    # NOTE(review): mutates the shared image object;
                    # presumably required so get_data_as_label sees a NIfTI
                    # filename — confirm.
                    lesion_img.set_filename('tmp.nii.gz')
                    lesion_atlas = get_data_as_label(lesion_img)
                    tmp_dict = compute_lesion_stats(
                        density.astype(bool), lesion_atlas,
                        voxel_sizes=lesion_voxel_sizes, single_label=True,
                        min_lesion_vol=min_lesion_vol,
                        precomputed_lesion_labels=computed_lesion_labels)

                    tmp_ind = _streamlines_in_mask(
                        list(streamlines),
                        lesion_atlas.astype(np.uint8),
                        np.eye(3), [0, 0, 0])
                    # Number of entries flagged 1 (presumably streamlines
                    # touching the lesion mask)
                    streamlines_count = np.count_nonzero(tmp_ind == 1)

                    if tmp_dict:
                        measures_to_return[lesion_filename+'vol'] = \
                            tmp_dict['lesion_total_volume']
                        measures_to_return[lesion_filename+'count'] = \
                            tmp_dict['lesion_count']
                        measures_to_return[lesion_filename+'sc'] = \
                            streamlines_count
                    else:
                        measures_to_return[lesion_filename+'vol'] = 0
                        measures_to_return[lesion_filename+'count'] = 0
                        measures_to_return[lesion_filename+'sc'] = 0

        if include_dps:
            for dps_key in hdf5_file[key].keys():
                # Skip the arrays that encode the streamlines themselves
                if dps_key not in ['data', 'offsets', 'lengths']:
                    out_file = os.path.join(include_dps, dps_key)
                    # 'commit' values are summed, everything else averaged
                    if 'commit' in dps_key:
                        measures_to_return[out_file] = np.sum(
                            hdf5_file[key][dps_key])
                    else:
                        measures_to_return[out_file] = np.average(
                            hdf5_file[key][dps_key])

        return {(in_label, out_label): measures_to_return}
Example #23
0
def compute_all_measures(args):
    """Compute pairwise similarity measures between two bundles.

    Parameters
    ----------
    args: sequence
        [0] ((filename_1, reference_1), (filename_2, reference_2)),
        [1] streamline_dice: if True, also compute streamline-wise Dice
            (only meaningful when both bundles come from the exact same
            tractogram),
        [2] disable_streamline_distance: if True, skip the
            streamline-based bundle adjacency.

    Returns
    -------
    dict or None
        Measure name -> value, or None if either bundle could not be
        loaded.

    Raises
    ------
    ValueError
        If the two references have incompatible headers.
    """
    tuple_1, tuple_2 = args[0]
    filename_1, reference_1 = tuple_1
    filename_2, reference_2 = tuple_2
    streamline_dice = args[1]
    disable_streamline_distance = args[2]

    if not is_header_compatible(reference_1, reference_2):
        raise ValueError('{} and {} have incompatible headers'.format(
            filename_1, filename_2))

    data_tuple_1 = load_data_tmp_saving([filename_1, reference_1, False,
                                         disable_streamline_distance])
    if data_tuple_1 is None:
        return None

    density_1, endpoints_density_1, bundle_1, \
        centroids_1 = data_tuple_1

    data_tuple_2 = load_data_tmp_saving([filename_2, reference_2, False,
                                         disable_streamline_distance])
    if data_tuple_2 is None:
        return None

    density_2, endpoints_density_2, bundle_2, \
        centroids_2 = data_tuple_2

    # np.product is deprecated (removed in NumPy 2.0); np.prod is equivalent
    _, _, voxel_sizes, _ = get_reference_info(reference_1)
    voxel_volume = np.prod(voxel_sizes)

    # These measures are in mm^3 (voxel counts scaled by voxel_volume below)
    binary_1 = copy.copy(density_1)
    binary_1[binary_1 > 0] = 1
    binary_2 = copy.copy(density_2)
    binary_2[binary_2 > 0] = 1
    volume_overlap = np.count_nonzero(binary_1 * binary_2)
    volume_overlap_endpoints = np.count_nonzero(
        endpoints_density_1 * endpoints_density_2)
    # Union minus intersection
    volume_overreach = np.abs(np.count_nonzero(
        binary_1 + binary_2) - volume_overlap)
    volume_overreach_endpoints = np.abs(np.count_nonzero(
        endpoints_density_1 + endpoints_density_2) - volume_overlap_endpoints)

    # These measures are in mm
    bundle_adjacency_voxel = compute_bundle_adjacency_voxel(density_1,
                                                            density_2,
                                                            non_overlap=True)
    if streamline_dice and not disable_streamline_distance:
        bundle_adjacency_streamlines = \
            compute_bundle_adjacency_streamlines(bundle_1,
                                                 bundle_2,
                                                 non_overlap=True)
    elif not disable_streamline_distance:
        bundle_adjacency_streamlines = \
            compute_bundle_adjacency_streamlines(bundle_1,
                                                 bundle_2,
                                                 centroids_1=centroids_1,
                                                 centroids_2=centroids_2,
                                                 non_overlap=True)
    # These measures are between 0 and 1
    dice_vox, w_dice_vox = compute_dice_voxel(density_1, density_2)

    dice_vox_endpoints, w_dice_vox_endpoints = compute_dice_voxel(
        endpoints_density_1,
        endpoints_density_2)
    density_correlation = compute_correlation(density_1, density_2)
    density_correlation_endpoints = compute_correlation(endpoints_density_1,
                                                        endpoints_density_2)

    measures_name = ['bundle_adjacency_voxels',
                     'dice_voxels', 'w_dice_voxels',
                     'volume_overlap',
                     'volume_overreach',
                     'dice_voxels_endpoints',
                     'w_dice_voxels_endpoints',
                     'volume_overlap_endpoints',
                     'volume_overreach_endpoints',
                     'density_correlation',
                     'density_correlation_endpoints']
    measures = [bundle_adjacency_voxel,
                dice_vox, w_dice_vox,
                volume_overlap * voxel_volume,
                volume_overreach * voxel_volume,
                dice_vox_endpoints,
                w_dice_vox_endpoints,
                volume_overlap_endpoints * voxel_volume,
                volume_overreach_endpoints * voxel_volume,
                density_correlation,
                density_correlation_endpoints]

    if not disable_streamline_distance:
        measures_name += ['bundle_adjacency_streamlines']
        measures += [bundle_adjacency_streamlines]

    # Only when the tractograms are exactly the same
    if streamline_dice:
        dice_streamlines, streamlines_intersect, streamlines_union = \
            compute_dice_streamlines(bundle_1, bundle_2)
        streamlines_count_overlap = len(streamlines_intersect)
        streamlines_count_overreach = len(
            streamlines_union) - len(streamlines_intersect)
        measures_name += ['dice_streamlines',
                          'streamlines_count_overlap',
                          'streamlines_count_overreach']
        measures += [dice_streamlines,
                     streamlines_count_overlap,
                     streamlines_count_overreach]

    return dict(zip(measures_name, measures))
Example #24
0
    def __init__(self, init_as=None, reference=None, store=None):
        """Initialize an empty TrxFile, supporting preallocation.

        Parameters
        ----------
        init_as: TrxFile, optional
            If given, its space attributes (VOXEL_TO_RASMM, DIMENSIONS),
            its positions dtype and the keys, dtypes and trailing shapes of
            its data_per_point, data_per_streamline, groups and
            data_per_group arrays are replicated with zero rows (no data
            is copied). Takes precedence over ``reference``.
        reference: reference object, optional
            Anything accepted by get_reference_info; provides the affine
            and dimensions when ``init_as`` is None.
        store: zarr store, optional
            Backing store; a zarr TempStore is created when None.
        """
        if init_as is not None:
            affine = init_as._zcontainer.attrs['VOXEL_TO_RASMM']
            dimensions = init_as._zcontainer.attrs['DIMENSIONS']
        elif reference is not None:
            affine, dimensions, _, _ = get_reference_info(reference)
        else:
            logging.debug('No reference provided, using blank space '
                          'attributes, please update them later.')
            affine = np.eye(4).astype(np.float32)
            dimensions = [1, 1, 1]

        if store is None:
            store = zarr.storage.TempStore()
        self._zcontainer = zarr.group(store=store, overwrite=True)
        self.voxel_to_rasmm = affine
        self.dimensions = dimensions
        self.nb_points = 0
        self.nb_streamlines = 0
        self._zstore = store

        # Explicit None test (consistent with the checks above/below): a
        # truthiness test could treat an existing but falsy init_as (e.g.
        # one whose __len__ is 0) as absent and drop its positions dtype.
        if init_as is not None:
            positions_dtype = init_as._zpos.dtype
        else:
            positions_dtype = np.float16
        self._zcontainer.create_dataset('positions',
                                        shape=(0, 3),
                                        chunks=(1000000, None),
                                        dtype=positions_dtype)

        self._zcontainer.create_dataset('offsets',
                                        shape=(0, ),
                                        chunks=(100000, ),
                                        dtype=np.uint64)

        self._zcontainer.create_group('data_per_point')
        self._zcontainer.create_group('data_per_streamline')
        self._zcontainer.create_group('data_per_group')
        self._zcontainer.create_group('groups')

        if init_as is None:
            return

        # Replicate each array layout with 0 rows; only the first axis is
        # chunked, the trailing axes use a single chunk (None).
        for dpp_key in init_as._zdpp.array_keys():
            empty_shape = list(init_as._zdpp[dpp_key].shape)
            empty_shape[0] = 0
            dtype = init_as._zdpp[dpp_key].dtype
            chunks = [1000000] + [None] * (len(empty_shape) - 1)

            self._zdpp.create_dataset(dpp_key,
                                      shape=empty_shape,
                                      chunks=chunks,
                                      dtype=dtype)

        for dps_key in init_as._zdps.array_keys():
            empty_shape = list(init_as._zdps[dps_key].shape)
            empty_shape[0] = 0
            dtype = init_as._zdps[dps_key].dtype
            chunks = [100000] + [None] * (len(empty_shape) - 1)

            self._zdps.create_dataset(dps_key,
                                      shape=empty_shape,
                                      chunks=chunks,
                                      dtype=dtype)

        for grp_key in init_as._zgrp.array_keys():
            empty_shape = list(init_as._zgrp[grp_key].shape)
            empty_shape[0] = 0
            dtype = init_as._zgrp[grp_key].dtype
            self._zgrp.create_dataset(grp_key,
                                      shape=empty_shape,
                                      chunks=(10000, ),
                                      dtype=dtype)

        for grp_key in init_as._zdpg.group_keys():
            # Only recreate data_per_group sub-groups that hold something
            if len(init_as._zdpg[grp_key]):
                self._zdpg.create_group(grp_key)
            for dpg_key in init_as._zdpg[grp_key].array_keys():
                empty_shape = list(init_as._zdpg[grp_key][dpg_key].shape)
                empty_shape[0] = 0
                dtype = init_as._zdpg[grp_key][dpg_key].dtype
                self._zdpg[grp_key].create_dataset(dpg_key,
                                                   shape=empty_shape,
                                                   chunks=None,
                                                   dtype=dtype)
Example #25
0
def transform_warp_streamlines(sft,
                               linear_transfo,
                               target,
                               inverse=False,
                               deformation_data=None,
                               remove_invalid=True,
                               cut_invalid=False):
    # TODO rename transform_warp_sft
    """Transform a tractogram with an affine, then optionally apply a warp
    from antsRegistration.
    Remove/Cut invalid streamlines to preserve sft validity.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
        Note: ``sft.to_rasmm()`` and ``sft.to_center()`` are called on this
        object before transforming.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel of
        ``target``. The x and y components are negated before being added
        (ITK/ANTs use LPS while nibabel uses RAS).
    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them. Keep the longest
        segment only. Takes precedence over ``remove_invalid``.

    Returns
    -------
    new_sft : StatefulTractogram
        Transformed tractogram in ``target`` space (Space.RASMM), carrying
        the original data_per_point and data_per_streamline.
    """
    sft.to_rasmm()
    sft.to_center()
    if inverse:
        linear_transfo = np.linalg.inv(linear_transfo)

    streamlines = transform_streamlines(sft.streamlines, linear_transfo)

    if deformation_data is not None:
        affine, _, _, _ = get_reference_info(target)
        inv_affine = np.linalg.inv(affine)

        # Because of duplication, an iteration over chunks of points is
        # necessary for a big dataset (especially if not compressed)
        streamlines = ArraySequence(streamlines)
        nb_points = len(streamlines._data)
        chunk_size = 1000000

        for start in range(0, nb_points, chunk_size):
            end = min(start + chunk_size, nb_points)
            points = streamlines._data[start:end]

            # To access the deformation information, we need to go in VOX space
            # No need for corner shift since we are doing interpolation
            # Shape (3, n_points): the coordinate layout map_coordinates uses
            cur_points_vox = np.array(transform_streamlines(
                points, inv_affine)).T

            x_def = map_coordinates(deformation_data[..., 0],
                                    cur_points_vox, order=1)
            y_def = map_coordinates(deformation_data[..., 1],
                                    cur_points_vox, order=1)
            z_def = map_coordinates(deformation_data[..., 2],
                                    cur_points_vox, order=1)

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
            final_points = np.array([-1 * x_def, -1 * y_def, z_def])
            final_points += np.array(points).T

            streamlines._data[start:end] = final_points.T

    new_sft = StatefulTractogram(streamlines,
                                 target,
                                 Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    return new_sft
Example #26
0
def main():
    """Run particle filtering tractography (PFT) and save the tractogram.

    Parses and validates the command line, builds a deterministic ('det')
    or probabilistic direction getter from the fODF SH coefficients, and a
    CMC (default) or ACT (``--act``) stopping criterion from the include /
    exclude maps. Seeds are drawn from the seeding mask and tracking runs in
    voxel space (identity affine). Streamlines are filtered on length,
    optionally compressed, and streamed to disk through a LazyTractogram.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(
        parser,
        [args.in_sh, args.in_seed, args.in_map_include, args.map_exclude_file])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be > than minL, (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will higly compress/linearize the tracts'.
                format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    # Compare against None explicitly: 0 is falsy, so a truthiness test
    # would skip validation and later silently fall back to 1 seed/voxel.
    if args.npv is not None and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt is not None and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.in_sh)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be ran robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_fdata(dtype=np.float32),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.in_map_include)
    map_exclude_img = nib.load(args.map_exclude_file)
    # Average voxel size of the include map, only used by the CMC criterion.
    avg_map_voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcStoppingCriterion(
            map_include_img.get_fdata(dtype=np.float32),
            map_exclude_img.get_fdata(dtype=np.float32),
            step_size=args.step_size,
            average_voxel_size=avg_map_voxel_size)
    else:
        tissue_classifier = ActStoppingCriterion(
            map_include_img.get_fdata(dtype=np.float32),
            map_exclude_img.get_fdata(dtype=np.float32))

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Tracking happens in voxel space, so mm quantities are converted with
    # the (isotropic, checked above) SH voxel size.
    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    seeds = track_utils.random_seeds_from_mask(
        get_data_as_mask(seed_img, dtype=bool),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        kept = [(s, p) for s, p in pft_streamlines
                if scaled_min_length <= length(s) <= scaled_max_length]
        # zip(*[]) cannot be unpacked into two names; handle the case where
        # nothing survives the length filter explicitly.
        if kept:
            filtered_streamlines, seeds = zip(*kept)
        else:
            filtered_streamlines, seeds = (), ()
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in pft_streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
Example #27
0
def main():
    """Decompose a tractogram into label-to-label connections stored in HDF5.

    Streamlines are loaded (invalid ones removed), moved to voxel space with
    corner origin, and intersected with the label volume. Every unordered
    pair of labels (plus each label with itself) is then processed
    independently through a cascade of optional filters: length pruning,
    loop/sharp-turn removal, outlier removal and QuickBundles curvature
    filtering. Each stage's kept/discarded streamlines are handed to
    ``_save_if_needed`` together with the HDF5 file and the parsed options.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    # 0 is background. Filter it by value: dropping the first unique entry
    # would discard a real label whenever 0 is absent from the volume.
    unique_labels = np.unique(data_labels)
    real_labels = unique_labels[unique_labels != 0]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connections is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSD/HD
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = _gather_pair_info(con_info, in_label, out_label)
            if not pair_info:
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))


def _gather_pair_info(con_info, in_label, out_label):
    """Return the connection records linking ``in_label`` and ``out_label``.

    Records are looked up in both directions (in->out and out->in). A label
    missing from ``con_info`` only means "no records in that direction"; it
    must not discard records already found in the other direction. For a
    self-connection both directions are the same entry, so it is read only
    once to avoid processing (and saving) every streamline twice.
    """
    pair_info = list(con_info.get(in_label, {}).get(out_label, []))
    if in_label != out_label:
        pair_info.extend(con_info.get(out_label, {}).get(in_label, []))
    return pair_info
Example #28
0
def alltrks2bin(tractseg_folder):
    """Rasterize every subject's TractSeg ``.trk`` bundles and average them.

    ``tractseg_folder`` is concatenated directly with sub-paths, so it must
    end with a path separator. Its first-level sub-directories (except
    ``average``) are the subjects. Bundle names are taken from the ``.trk``
    files found under the first subject's ``tracts`` folder.

    For each bundle, each subject's streamlines are turned into a binary
    voxel mask, saved next to the ``.trk`` as ``<bundle>.nii.gz``. The masks
    are averaged over the subjects that were processed successfully. The
    average is saved at the reference resolution and resliced to 2.5 mm
    isotropic into ``<folder>/average``. Bundles whose two average files
    already exist are skipped. Errors are collected and printed at the end
    instead of aborting the whole run.
    """
    reference_anatomy = tractseg_folder + 'reference_anatomy.nii.gz'
    affine, dimensions, voxel_sizes, _ = get_reference_info(reference_anatomy)
    low_res = (2.5, 2.5, 2.5)

    # Only the first level of the walk matters: its directories are subjects.
    _, top_dirs, _ = next(walk(tractseg_folder), (tractseg_folder, [], []))
    subjects = [d for d in top_dirs if d != 'average']
    if not subjects:
        print('No subject folder found in', tractseg_folder, '. Aborting...')
        return

    # assuming all subject folders contain all tracts (given for TractSeg samples)
    tracts = []
    for _, _, files in walk(tractseg_folder + subjects[0] + '/tracts'):
        tracts.extend(f[:-4] for f in files if f.endswith('.trk'))

    # make a folder to save average masks into
    try:
        Path(tractseg_folder + 'average').mkdir(parents=True, exist_ok=True)
    except OSError:
        print('Could not create output dir. Aborting...')
        return

    errs = []
    for i, t in enumerate(tracts):
        print('Current tract: ', t, ' (', (i * 100) / len(tracts), '%)')
        high_res_path = tractseg_folder + 'average/' + t + '.nii.gz'
        low_res_path = tractseg_folder + 'average/' + t + '_low_res.nii.gz'

        if Path(high_res_path).exists() and Path(low_res_path).exists():
            continue

        # Builtin float (float64): the `np.float` alias was removed in
        # NumPy 1.24.
        t_ave = np.zeros(dimensions, dtype=float)
        nb_ok = 0
        for j, s in enumerate(subjects):
            print('Current subject: ', s, ' (',
                  (j * 100) / len(subjects), '%)')
            try:
                trk_data = load_trk(
                    tractseg_folder + s + '/tracts/' + t + '.trk',
                    reference_anatomy)
                # Corner origin so that truncating to int gives the voxel
                # index containing each point (with center origin, a point
                # at 2.7 belongs to voxel 3 but would be cast to 2).
                trk_data.to_vox()
                trk_data.to_corner()
                streams = np.vstack(trk_data.streamlines).astype(np.int32)
                mask = np.zeros(dimensions, dtype=float)
                mask[streams[:, 0], streams[:, 1], streams[:, 2]] = 1
                save_nifti(
                    tractseg_folder + s + '/tracts/' + t + '.nii.gz', mask,
                    affine)
            except Exception as e:
                errs.append('Error with tract ' + t + ' in subject ' + s +
                            ': ' + str(e))
                continue

            t_ave += mask
            nb_ok += 1

        # Average over the subjects that actually contributed, otherwise
        # failed subjects would dilute the mean as if they were empty.
        if nb_ok == 0:
            errs.append('Error with average of tract ' + t +
                        ': no subject could be processed')
            continue

        try:
            t_ave /= nb_ok
            save_nifti(high_res_path, t_ave, affine)
            t_ave_low_res, affine_low_res = reslice(
                t_ave, affine, voxel_sizes, low_res)
            save_nifti(low_res_path, t_ave_low_res, affine_low_res)
        except Exception as e:
            errs.append('Error with average of tract ' + t + ': ' + str(e))

    for e in errs:
        print(e)
Example #29
0
def main():
    """Apply a transformation to every connection of an HDF5 connectivity file.

    The input HDF5 is copied to ``out_hdf5``. Each non-empty connection
    group is moved from the input reference (stored in the file's root
    attributes, streamlines in voxel space / TRACKVIS origin) to the target
    image space using the linear transform and the optional deformation
    field. Its ``data`` / ``offsets`` / ``lengths`` datasets are then
    rewritten in voxel space with corner origin. The output root attributes
    are set to the target reference.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_target_file,
                                 args.in_transfo], args.in_deformation)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    # Everything below is identical for all connections: load it once.
    transfo = load_matrix_in_any_format(args.in_transfo)
    deformation_data = None
    if args.in_deformation is not None:
        deformation_data = np.squeeze(nib.load(
            args.in_deformation).get_fdata(dtype=np.float32))
    target_img = nib.load(args.in_target_file)
    target_affine, target_dims, target_vox_sizes, target_vox_order = \
        get_reference_info(target_img)

    shutil.copy(args.in_hdf5, args.out_hdf5)
    with h5py.File(args.in_hdf5, 'r') as in_hdf5_file:
        with h5py.File(args.out_hdf5, 'a') as out_hdf5_file:
            # Reference of the input streamlines, shared by every group.
            in_header = create_nifti_header(in_hdf5_file.attrs['affine'],
                                            in_hdf5_file.attrs['dimensions'],
                                            in_hdf5_file.attrs['voxel_sizes'])

            # Every connection written below lives in the target space, so
            # the output reference is set even if all groups are empty.
            out_hdf5_file.attrs['affine'] = target_affine
            out_hdf5_file.attrs['dimensions'] = target_dims
            out_hdf5_file.attrs['voxel_sizes'] = target_vox_sizes
            out_hdf5_file.attrs['voxel_order'] = target_vox_order

            for key in in_hdf5_file.keys():
                streamlines = reconstruct_streamlines_from_hdf5(
                    in_hdf5_file, key)

                if len(streamlines) == 0:
                    continue

                moving_sft = StatefulTractogram(streamlines, in_header,
                                                Space.VOX,
                                                origin=Origin.TRACKVIS)

                new_sft = transform_warp_streamlines(
                    moving_sft, transfo, target_img,
                    inverse=args.inverse,
                    deformation_data=deformation_data,
                    remove_invalid=not args.cut_invalid,
                    cut_invalid=args.cut_invalid)
                new_sft.to_vox()
                new_sft.to_corner()

                group = out_hdf5_file[key]
                del group['data']
                group.create_dataset('data',
                                     data=new_sft.streamlines.get_data())
                del group['offsets']
                group.create_dataset('offsets',
                                     data=new_sft.streamlines._offsets)
                del group['lengths']
                group.create_dataset('lengths',
                                     data=new_sft.streamlines._lengths)
Example #30
0
def main():
    """Fix a DSI Studio tractogram so it matches its reference space.

    The streamlines are re-expressed against the DSI Studio FA
    (``in_dsi_fa``), flipped on x/y (LPS -> RAS in voxel space) and shifted.
    Without ``--in_native_fa`` the result is saved directly. Otherwise, the
    flips implied by the DSI Studio FA voxel order are undone. A rigid
    transform to the native FA is either loaded (``--load_transfo``) or
    estimated by mutual-information registration (optionally saved with
    ``--save_transfo``). The streamlines are then warped into the native
    space before saving.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.load_transfo and args.in_native_fa is None:
        parser.error('When loading a transformation, the final reference is '
                     'needed, use --in_native_fa.')
    # The transformation file is also read/written, so validate it up front
    # (existence / overwrite protection) instead of failing mid-run.
    assert_inputs_exist(parser, [args.in_dsi_tractogram, args.in_dsi_fa],
                        optional=[args.in_native_fa, args.load_transfo])
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.save_transfo)

    sft = load_tractogram(args.in_dsi_tractogram,
                          'same',
                          bbox_valid_check=False)

    # LPS -> RAS convention in voxel space
    sft.to_vox()
    flip_axis = ['x', 'y']
    sft_fix = StatefulTractogram(sft.streamlines, args.in_dsi_fa, Space.VOXMM)
    sft_fix.to_vox()
    sft_fix.streamlines._data -= get_axis_shift_vector(flip_axis)

    sft_flip = flip_sft(sft_fix, flip_axis)

    sft_flip.to_rasmm()
    sft_flip.streamlines._data -= [0.5, 0.5, -0.5]

    if not args.in_native_fa:
        if args.cut_invalid:
            sft_flip, _ = cut_invalid_streamlines(sft_flip)
        elif args.remove_invalid:
            sft_flip.remove_invalid_streamlines()
        save_tractogram(sft_flip,
                        args.out_tractogram,
                        bbox_valid_check=not args.keep_invalid)
        return

    static_img = nib.load(args.in_native_fa)
    moving_img = nib.load(args.in_dsi_fa)

    # DSI-Studio flips the volume without changing the affine (I think)
    # So this has to be reversed (not the same problem as above)
    vox_order = get_reference_info(moving_img)[3]
    flip_axis = []
    if vox_order[0] == 'L':
        flip_axis.append('x')
    if vox_order[1] == 'P':
        flip_axis.append('y')
    if vox_order[2] == 'I':
        flip_axis.append('z')
    sft_flip_back = flip_sft(sft_flip, flip_axis)

    if args.load_transfo:
        transfo = np.loadtxt(args.load_transfo)
    else:
        # The image volumes are only needed to estimate the registration.
        static_data = static_img.get_fdata()
        moving_data = moving_img.get_fdata()
        if 'x' in flip_axis:
            moving_data = moving_data[::-1, :, :]
        if 'y' in flip_axis:
            moving_data = moving_data[:, ::-1, :]
        if 'z' in flip_axis:
            moving_data = moving_data[:, :, ::-1]

        # Sometimes DSI studio has quite a lot of skull left
        # Dipy Median Otsu does not work with FA/GFA
        if args.auto_crop:
            moving_data = cube_crop_data(moving_data)
            static_data = cube_crop_data(static_data)

        # Since DSI Studio register to AC/PC and does not save the
        # transformation We must estimate the transformation, since it's
        # rigid it is 'easy'
        c_of_mass = transform_centers_of_mass(static_data,
                                              static_img.affine,
                                              moving_data,
                                              moving_img.affine)

        nbins = 32
        sampling_prop = None
        level_iters = [1000, 100, 10]
        sigmas = [3.0, 2.0, 1.0]
        factors = [3, 2, 1]
        metric = MutualInformationMetric(nbins, sampling_prop)
        affreg = AffineRegistration(metric=metric,
                                    level_iters=level_iters,
                                    sigmas=sigmas,
                                    factors=factors)
        transform = RigidTransform3D()
        rigid = affreg.optimize(static_data,
                                moving_data,
                                transform,
                                None,
                                static_img.affine,
                                moving_img.affine,
                                starting_affine=c_of_mass.affine)
        transfo = rigid.affine
        if args.save_transfo:
            np.savetxt(args.save_transfo, transfo)

    new_sft = transform_warp_sft(sft_flip_back,
                                 transfo,
                                 static_img,
                                 inverse=True,
                                 remove_invalid=args.remove_invalid,
                                 cut_invalid=args.cut_invalid)

    if args.cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif args.remove_invalid:
        new_sft.remove_invalid_streamlines()
    save_tractogram(new_sft,
                    args.out_tractogram,
                    bbox_valid_check=not args.keep_invalid)