Exemplo n.º 1
0
    def remove_invalid_streamlines(self):
        """ Remove streamlines with invalid coordinates from the object.
        Will also remove the data_per_point and data_per_streamline.
        Invalid coordinates are any X,Y,Z values above the reference
        dimensions or below zero. The check is done in voxel space with
        the corner origin; the original space and origin are restored
        before returning.

        Returns
        -------
        output : tuple
            Tuple of two list, indices_to_remove, indices_to_keep.
            indices_to_remove is sorted and free of duplicates;
            indices_to_keep is an integer ndarray.
        """
        old_space = deepcopy(self.space)
        old_shift = deepcopy(self.shifted_origin)

        self.to_vox()
        self.to_corner()

        points = self._tractogram.streamlines.data
        min_condition = np.min(points, axis=1) < 0.0
        max_condition = np.any(points > self._dimensions, axis=1)
        invalid_points = np.where(
            np.logical_or(min_condition, max_condition))[0]

        # The streamline owning point i is the last one whose offset is
        # <= i, i.e. bisect_right(offsets, i) - 1, computed here for all
        # invalid points at once. np.unique drops the duplicates produced
        # by streamlines that have several invalid points.
        offsets = self._tractogram.streamlines._offsets
        indices_to_remove = np.unique(
            np.searchsorted(offsets, invalid_points, side='right') - 1
        ).tolist()

        indices_to_keep = np.setdiff1d(
            np.arange(len(self._tractogram)),
            np.array(indices_to_remove, dtype=int)).astype(int)

        # Index the ArraySequence directly: itemgetter(*indices) returned a
        # single bare array when one streamline was kept and raised a
        # TypeError when none were.
        tmp_streamlines = \
            self._tractogram.streamlines[indices_to_keep].copy()
        tmp_data_per_point = {
            key: value[indices_to_keep]
            for key, value in self._tractogram.data_per_point.items()}
        tmp_data_per_streamline = {
            key: value[indices_to_keep]
            for key, value in self._tractogram.data_per_streamline.items()}

        self._tractogram = Tractogram(tmp_streamlines,
                                      affine_to_rasmm=np.eye(4))

        self._tractogram.data_per_point = tmp_data_per_point
        self._tractogram.data_per_streamline = tmp_data_per_streamline

        if old_space == Space.RASMM:
            self.to_rasmm()
        elif old_space == Space.VOXMM:
            self.to_voxmm()

        # old_shift False means the streamlines were at the center origin
        if not old_shift:
            self.to_center()

        return indices_to_remove, indices_to_keep
Exemplo n.º 2
0
    def remove_invalid_streamlines(self, epsilon=1e-3):
        """ Remove streamlines with invalid coordinates from the object.
        Will also remove the data_per_point and data_per_streamline.
        Invalid coordinates are any X,Y,Z values above the reference
        dimensions or below zero. The check is done in voxel space with
        the corner origin: a point is invalid if any coordinate is below
        ``epsilon`` or above ``dimensions - epsilon``. The original space
        and origin are restored before returning.

        Parameters
        ----------
        epsilon : float (optional)
            Epsilon value for the bounding box verification.
            Default is 1e-3.

        Returns
        -------
        output : tuple
            Tuple of two list, indices_to_remove, indices_to_keep.
            indices_to_remove is sorted and free of duplicates. Both lists
            are empty when the object holds no streamlines.
        """
        if not self.streamlines:
            return [], []

        old_space = deepcopy(self.space)
        old_origin = deepcopy(self.origin)

        self.to_vox()
        self.to_corner()

        points = self._tractogram.streamlines._data
        min_condition = np.min(points, axis=1) < epsilon
        max_condition = np.any(points > self._dimensions - epsilon, axis=1)
        invalid_points = np.where(np.logical_or(min_condition,
                                                max_condition))[0]

        # The streamline owning point i is the last one whose offset is
        # <= i, i.e. bisect_right(offsets, i) - 1, computed for all invalid
        # points at once; np.unique sorts and de-duplicates.
        offsets = self._tractogram.streamlines._offsets
        indices_to_remove = np.unique(
            np.searchsorted(offsets, invalid_points, side='right') - 1
        ).tolist()

        # Keep an integer ndarray for indexing: an empty Python list turns
        # into a float array, which cannot be used as an index.
        keep = np.setdiff1d(np.arange(len(self._tractogram)),
                            np.array(indices_to_remove, dtype=int)).astype(int)
        indices_to_keep = list(keep)

        tmp_streamlines = self.streamlines[keep]
        tmp_dpp = self._tractogram.data_per_point[keep]
        tmp_dps = self._tractogram.data_per_streamline[keep]

        self._tractogram = Tractogram(tmp_streamlines.copy(),
                                      data_per_point=tmp_dpp,
                                      data_per_streamline=tmp_dps,
                                      affine_to_rasmm=np.eye(4))

        self.to_space(old_space)
        self.to_origin(old_origin)

        return indices_to_remove, indices_to_keep
Exemplo n.º 3
0
def main():
    """Split a bundle into inlier and outlier streamlines.

    Loads ``args.in_bundle``, scores its streamlines with
    ``outliers_removal_using_hierarchical_quickbundles`` and separates them
    with ``prune`` at threshold ``args.alpha`` (must be in ]0, 1]).
    Inliers are saved to ``args.out_bundle`` and outliers to
    ``args.remaining_bundle`` when that option is given; each output keeps
    the matching data_per_streamline / data_per_point and the input header.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundle)
    assert_outputs_exist(parser, args, args.out_bundle, args.remaining_bundle)
    if args.alpha <= 0 or args.alpha > 1:
        parser.error('--alpha should be ]0, 1]')

    tractogram = nib.streamlines.load(args.in_bundle)

    if int(tractogram.header['nb_streamlines']) == 0:
        logging.warning("Bundle file contains no streamline")
        return

    check_tracts_same_format(
        parser, [args.in_bundle, args.out_bundle, args.remaining_bundle])

    streamlines = tractogram.streamlines

    summary = outliers_removal_using_hierarchical_quickbundles(streamlines)
    outliers, inliers = prune(streamlines, args.alpha, summary)

    inliers_streamlines = tractogram.streamlines[inliers]
    inliers_data_per_streamline = tractogram.tractogram.data_per_streamline[
        inliers]
    inliers_data_per_point = tractogram.tractogram.data_per_point[inliers]

    outliers_streamlines = tractogram.streamlines[outliers]
    outliers_data_per_streamline = tractogram.tractogram.data_per_streamline[
        outliers]
    outliers_data_per_point = tractogram.tractogram.data_per_point[outliers]

    if len(inliers_streamlines) == 0:
        # Trailing space: the two literals are concatenated.
        logging.warning("All streamlines are considered outliers. "
                        "Please lower the --alpha parameter")
    else:
        inliers_tractogram = Tractogram(
            inliers_streamlines,
            affine_to_rasmm=np.eye(4),
            data_per_streamline=inliers_data_per_streamline,
            data_per_point=inliers_data_per_point)
        nib.streamlines.save(inliers_tractogram,
                             args.out_bundle,
                             header=tractogram.header)

    if len(outliers_streamlines) == 0:
        logging.warning("No outlier found. Please raise the --alpha parameter")
    elif args.remaining_bundle:
        outlier_tractogram = Tractogram(
            outliers_streamlines,
            affine_to_rasmm=np.eye(4),
            data_per_streamline=outliers_data_per_streamline,
            data_per_point=outliers_data_per_point)
        nib.streamlines.save(outlier_tractogram,
                             args.remaining_bundle,
                             header=tractogram.header)
Exemplo n.º 4
0
    def remove_invalid_streamlines(self):
        """ Remove streamlines with invalid coordinates from the object.
        Will also remove the data_per_point and data_per_streamline.
        Invalid coordinates are any X,Y,Z values above the reference
        dimensions or below zero. The check is done in voxel space with
        the corner origin; the original space and origin are restored
        before returning.

        Returns
        -------
        output : tuple
            Tuple (indices_to_remove, indices_to_keep): a sorted list of
            unique removed indices and an integer ndarray of kept indices.
            Both are empty when the object holds no streamlines.
        """
        if not self.streamlines:
            return [], np.array([], dtype=int)

        old_space = deepcopy(self.space)
        old_origin = deepcopy(self.origin)

        self.to_vox()
        self.to_corner()

        points = self._tractogram.streamlines._data
        min_condition = np.min(points, axis=1) < 0.0
        max_condition = np.any(points > self._dimensions, axis=1)
        invalid_points = np.where(
            np.logical_or(min_condition, max_condition))[0]

        # The streamline owning point i is the last one whose offset is
        # <= i, i.e. bisect_right(offsets, i) - 1, computed for all invalid
        # points at once; np.unique sorts and removes duplicates.
        offsets = self._tractogram.streamlines._offsets
        indices_to_remove = np.unique(
            np.searchsorted(offsets, invalid_points, side='right') - 1
        ).tolist()

        indices_to_keep = np.setdiff1d(
            np.arange(len(self._tractogram)),
            np.array(indices_to_remove, dtype=int)).astype(int)

        tmp_streamlines = self.streamlines[indices_to_keep]
        tmp_data_per_point = self._tractogram.data_per_point[indices_to_keep]
        tmp_data_per_streamline = \
            self._tractogram.data_per_streamline[indices_to_keep]

        self._tractogram = Tractogram(
            tmp_streamlines.copy(),
            data_per_point=tmp_data_per_point,
            data_per_streamline=tmp_data_per_streamline,
            affine_to_rasmm=np.eye(4))

        self.to_space(old_space)
        self.to_origin(old_origin)

        return indices_to_remove, indices_to_keep
Exemplo n.º 5
0
def main():
    """Prune a bundle by streamline length and save the result.

    Validates that ``--min_length`` is non-negative and strictly smaller
    than ``--max_length``, filters the streamlines of ``args.bundle`` with
    ``subsample_streamlines`` and writes them to ``args.pruned_bundle``
    with the input header. Nothing is written if every streamline is
    removed; a message is printed instead.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.bundle])
    assert_outputs_exists(parser, args, [args.pruned_bundle])

    lower, upper = args.min_length, args.max_length
    if lower < 0:
        parser.error('--min_length {} should be at least 0'.format(lower))
    if upper <= lower:
        parser.error(
            '--max_length {} should be greater than --min_length'.format(
                upper))

    source = nib.streamlines.load(args.bundle)
    kept = subsample_streamlines(source.streamlines, lower, upper)

    if kept:
        nib.streamlines.save(Tractogram(kept, affine_to_rasmm=np.eye(4)),
                             args.pruned_bundle,
                             header=source.header)
    else:
        print("Pruning removed all the streamlines. Please adjust "
              "--{min,max}_length")
Exemplo n.º 6
0
def get_centroid_streamline(tractogram, nb_points, distance_threshold):
    """Return the QuickBundles centroid of a tractogram.

    Streamlines are clustered with QuickBundles using an average pointwise
    Euclidean metric on streamlines resampled to ``nb_points`` points.

    Parameters
    ----------
    tractogram : object with a ``streamlines`` attribute
        Streamlines to cluster.
    nb_points : int
        Number of points used by the resampling feature.
    distance_threshold : float
        QuickBundles clustering threshold.

    Returns
    -------
    Tractogram
        Tractogram holding the centroid streamline (empty if clustering
        produced no centroid), with an identity ``affine_to_rasmm``.

    Raises
    ------
    ValueError
        If clustering produced more than one centroid.
    """
    resample_feature = ResampleFeature(nb_points=nb_points)
    metric = AveragePointwiseEuclideanMetric(resample_feature)
    quick_bundle = QuickBundles(threshold=distance_threshold, metric=metric)
    clusters = quick_bundle.cluster(tractogram.streamlines)
    centroid_streamlines = clusters.centroids

    if len(centroid_streamlines) > 1:
        # ValueError is a subclass of Exception, so existing
        # `except Exception` handlers keep working.
        raise ValueError('Multiple centroids found')

    return Tractogram(centroid_streamlines, affine_to_rasmm=np.eye(4))
Exemplo n.º 7
0
def save_tractogram(sft, filename, bbox_valid_check=True):
    """ Save the stateful tractogram in any format (trk, tck, vtk, fib, dpy)

    The streamlines are temporarily moved to RASMM space with the center
    origin for writing; the original space and origin of ``sft`` are
    restored afterwards, even if writing fails.

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save
    filename : string
        Filename with valid extension
    bbox_valid_check : bool, optional
        If True (default), refuse to save when
        ``sft.is_bbox_in_vox_valid()`` returns False.

    Returns
    -------
    output : bool
        Did the saving work properly

    Raises
    ------
    TypeError
        If the extension of ``filename`` is not a supported format.
    ValueError
        If ``bbox_valid_check`` is True and the bounding box is not valid
        in voxel space.

    Notes
    -----
    data_per_point and data_per_streamline are only written for .trk files.
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        # The exception used to be created but never raised, so an
        # unsupported extension silently "succeeded" without writing.
        raise TypeError('Output filename is not one of the supported format')

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'save a valid file if some coordinates are invalid')

    old_space = deepcopy(sft.space)
    old_shift = deepcopy(sft.shifted_origin)

    sft.to_rasmm()
    sft.to_center()

    try:
        timer = time.time()
        if extension in ['.trk', '.tck']:
            tractogram_type = detect_format(filename)
            header = create_tractogram_header(tractogram_type,
                                              *sft.space_attribute)
            new_tractogram = Tractogram(sft.streamlines,
                                        affine_to_rasmm=np.eye(4))

            if extension == '.trk':
                new_tractogram.data_per_point = sft.data_per_point
                new_tractogram.data_per_streamline = sft.data_per_streamline

            fileobj = tractogram_type(new_tractogram, header=header)
            nib.streamlines.save(fileobj, filename)

        elif extension in ['.vtk', '.fib']:
            save_vtk_streamlines(sft.streamlines, filename, binary=True)
        elif extension in ['.dpy']:
            dpy_obj = Dpy(filename, mode='w')
            dpy_obj.write_tracks(sft.streamlines)
            dpy_obj.close()

        logging.debug('Save %s with %s streamlines in %s seconds',
                      filename, len(sft), round(time.time() - timer, 3))
    finally:
        # Restore the caller's state whether or not writing succeeded.
        if old_space == Space.VOX:
            sft.to_vox()
        elif old_space == Space.VOXMM:
            sft.to_voxmm()

        # old_shift True means the streamlines were at the corner origin
        if old_shift:
            sft.to_corner()

    return True
Exemplo n.º 8
0
    def __init__(self, streamlines, reference, space,
                 origin=Origin.NIFTI,
                 data_per_point=None, data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            After tracking the space is VOX, after loading with nibabel
            the space is RASMM
        origin : Enum (dipy.io.stateful_tractogram.Origin), optional
            Current origin in which the streamlines are (center or corner)
            After loading with nibabel the origin is CENTER
        data_per_point : dict, optional
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict, optional
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If ``reference`` is a 4-tuple of invalid space attributes, or
            is not one of the supported reference types.
        ValueError
            If ``space`` is not a ``Space`` member or ``origin`` is not an
            ``Origin`` member.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        # Copy an incoming Streamlines object so this tractogram does not
        # alias the caller's streamlines.
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        if isinstance(reference, type(self)):
            logger.warning('Using a StatefulTractogram as reference, this '
                           'will copy only the space_attributes, not '
                           'the state. The variables space and origin '
                           'must be specified separately.')
            logger.warning('To copy the state from another StatefulTractogram '
                           'you may want to use the function from_sft '
                           '(static function of the StatefulTractogram).')

        # A 4-tuple is taken as raw (affine, dimensions, voxel_sizes,
        # voxel_order); anything else is resolved by get_reference_info.
        if isinstance(reference, tuple) and len(reference) == 4:
            if is_reference_info_valid(*reference):
                space_attributes = reference
            else:
                raise TypeError('The provided space attributes are not '
                                'considered valid, please correct before '
                                'using them with StatefulTractogram.')
        else:
            space_attributes = get_reference_info(reference)
            if space_attributes is None:
                raise TypeError('Reference MUST be one of the following:\n'
                                'Nifti or Trk filename, Nifti1Image or '
                                'TrkFile, Nifti1Header or trk.header (dict).')

        (self._affine, self._dimensions,
         self._voxel_sizes, self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance (rather than `x in Space`) gives the documented
        # ValueError for every non-member: before Python 3.12 `x in Enum`
        # raises TypeError for non-members, and from 3.12 it accepts raw
        # values (e.g. 'vox') that would never compare equal to Space.X.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX.')
        self._space = space

        if not isinstance(origin, Origin):
            raise ValueError('Origin MUST be from Origin enum, '
                             'e.g Origin.NIFTI.')
        self._origin = origin
        logger.debug(self)
Exemplo n.º 9
0
class StatefulTractogram(object):
    """ Class for stateful representation of collections of streamlines
    Object designed to be identical no matter the file format
    (trk, tck, vtk, fib, dpy). Facilitate transformation between space and
    data manipulation for each streamline / point.
    """

    def __init__(self, streamlines, reference, space,
                 origin=Origin.NIFTI,
                 data_per_point=None, data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            After tracking the space is VOX, after loading with nibabel
            the space is RASMM
        origin : Enum (dipy.io.stateful_tractogram.Origin), optional
            Current origin in which the streamlines are (center or corner)
            After loading with nibabel the origin is CENTER
        data_per_point : dict, optional
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict, optional
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If ``reference`` is a 4-tuple of invalid space attributes, or
            is not one of the supported reference types.
        ValueError
            If ``space`` is not a ``Space`` member or ``origin`` is not an
            ``Origin`` member.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        # Copy an incoming Streamlines object so this tractogram does not
        # alias the caller's streamlines.
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        if isinstance(reference, type(self)):
            logger.warning('Using a StatefulTractogram as reference, this '
                           'will copy only the space_attributes, not '
                           'the state. The variables space and origin '
                           'must be specified separately.')
            logger.warning('To copy the state from another StatefulTractogram '
                           'you may want to use the function from_sft '
                           '(static function of the StatefulTractogram).')

        # A 4-tuple is taken as raw (affine, dimensions, voxel_sizes,
        # voxel_order); anything else is resolved by get_reference_info.
        if isinstance(reference, tuple) and len(reference) == 4:
            if is_reference_info_valid(*reference):
                space_attributes = reference
            else:
                raise TypeError('The provided space attributes are not '
                                'considered valid, please correct before '
                                'using them with StatefulTractogram.')
        else:
            space_attributes = get_reference_info(reference)
            if space_attributes is None:
                raise TypeError('Reference MUST be one of the following:\n'
                                'Nifti or Trk filename, Nifti1Image or '
                                'TrkFile, Nifti1Header or trk.header (dict).')

        (self._affine, self._dimensions,
         self._voxel_sizes, self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance (rather than `x in Space`) gives the documented
        # ValueError for every non-member: before Python 3.12 `x in Enum`
        # raises TypeError for non-members, and from 3.12 it accepts raw
        # values (e.g. 'vox') that would never compare equal to Space.X.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX.')
        self._space = space

        if not isinstance(origin, Origin):
            raise ValueError('Origin MUST be from Origin enum, '
                             'e.g Origin.NIFTI.')
        self._origin = origin
        logger.debug(self)

    @staticmethod
    def are_compatible(sft_1, sft_2):
        """ Check that two StatefulTractogram can be combined.

        They must share spatial attributes, space, origin, data_per_point
        keys and data_per_streamline keys. Every check is run and every
        mismatch is logged as a warning, even after a first failure.

        Returns
        -------
        output : bool
            True when no inconsistency was found
        """
        compatible = True

        if not is_header_compatible(sft_1, sft_2):
            logger.warning('Inconsistent spatial attributes between both sft.')
            compatible = False

        checks = (
            (sft_1.space != sft_2.space,
             'Inconsistent space between both sft.'),
            (sft_1.origin != sft_2.origin,
             'Inconsistent origin between both sft.'),
            (sft_1.get_data_per_point_keys()
             != sft_2.get_data_per_point_keys(),
             'Inconsistent data_per_point between both sft.'),
            (sft_1.get_data_per_streamline_keys()
             != sft_2.get_data_per_streamline_keys(),
             'Inconsistent data_per_streamline between both sft.'),
        )
        for mismatch, message in checks:
            if mismatch:
                logger.warning(message)
                compatible = False

        return compatible

    @staticmethod
    def from_sft(streamlines, sft,
                 data_per_point=None,
                 data_per_streamline=None):
        """ Create an instance of `StatefulTractogram` from another instance
        of `StatefulTractogram`, copying its space attributes AND its state
        (space and origin).

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the new tractogram
        sft : StatefulTractogram
            The other StatefulTractogram to copy the space_attributes and
            the state (space, origin) from.
        data_per_point : dict, optional
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict, optional
            Dictionary in which each key has X items
            X being the number of streamlines

        Returns
        -------
        output : StatefulTractogram
            New tractogram in the same space and origin as ``sft``
        """
        attributes = sft.space_attributes
        return StatefulTractogram(streamlines, attributes, sft.space,
                                  origin=sft.origin,
                                  data_per_point=data_per_point,
                                  data_per_streamline=data_per_streamline)

    def __str__(self):
        """ Generate the string for printing """
        text = 'Affine: \n{}'.format(
            np.array2string(self._affine,
                            formatter={'float_kind': lambda x: "%.6f" % x}))
        text += '\ndimensions: {}'.format(
            np.array2string(self._dimensions))
        text += '\nvoxel_sizes: {}'.format(
            np.array2string(self._voxel_sizes,
                            formatter={'float_kind': lambda x: "%.2f" % x}))
        text += '\nvoxel_order: {}'.format(self._voxel_order)

        text += '\nstreamline_count: {}'.format(self._get_streamline_count())
        text += '\npoint_count: {}'.format(self._get_point_count())
        text += '\ndata_per_streamline keys: {}'.format(
            self.get_data_per_streamline_keys())
        text += '\ndata_per_point keys: {}'.format(
            self.get_data_per_point_keys())

        return text

    def __len__(self):
        """ Define the length of the object """
        return self._get_streamline_count()

    def __getitem__(self, key):
        """ Slice all data in a consistent way """
        if isinstance(key, int):
            key = [key]

        return self.from_sft(self.streamlines[key], self,
                             data_per_point=self.data_per_point[key],
                             data_per_streamline=self.data_per_streamline[key])

    def __eq__(self, other):
        """ Robust StatefulTractogram equality test """
        if not self.are_compatible(self, other):
            return False

        streamlines_equal = np.allclose(self.streamlines.get_data(),
                                        other.streamlines.get_data())
        if not streamlines_equal:
            return False

        dpp_equal = True
        for key in self.data_per_point:
            dpp_equal = dpp_equal and np.allclose(
                self.data_per_point[key].get_data(),
                other.data_per_point[key].get_data())
        if not dpp_equal:
            return False

        dps_equal = True
        for key in self.data_per_streamline:
            dps_equal = dps_equal and np.allclose(
                self.data_per_streamline[key],
                other.data_per_streamline[key])
        if not dps_equal:
            return False

        return True

    def __ne__(self, other):
        """ Robust StatefulTractogram equality test (NOT) """
        return not self == other

    def __add__(self, other_sft):
        """ Concatenate two compatible StatefulTractogram.

        Streamlines, data_per_point and data_per_streamline of ``other_sft``
        are appended after those of ``self`` in a new object; neither
        operand is modified.

        Raises
        ------
        ValueError
            If ``are_compatible`` reports an inconsistency.
        """
        if not self.are_compatible(self, other_sft):
            logger.debug(self)
            logger.debug(other_sft)
            raise ValueError('Inconsistent StatefulTractogram.\n'
                             'Make sure Space, Origin are the same and that '
                             'data_per_point and data_per_streamline keys are '
                             'the same.')

        merged_streamlines = self.streamlines.copy()
        merged_streamlines.extend(other_sft.streamlines)

        merged_data = {}
        for attribute in ('data_per_point', 'data_per_streamline'):
            container = deepcopy(getattr(self, attribute))
            container.extend(getattr(other_sft, attribute))
            merged_data[attribute] = container

        return self.from_sft(merged_streamlines, self, **merged_data)

    def __iadd__(self, other):
        self.value = self + other
        return self.value

    @property
    def space_attributes(self):
        """ Getter for spatial attribute """
        return self._affine, self._dimensions, self._voxel_sizes, \
            self._voxel_order

    @property
    def space(self):
        """ Getter for the current space """
        return self._space

    @property
    def affine(self):
        """ Getter for the reference affine """
        return self._affine

    @property
    def dimensions(self):
        """ Getter for the reference dimensions """
        return self._dimensions

    @property
    def voxel_sizes(self):
        """ Getter for the reference voxel sizes """
        return self._voxel_sizes

    @property
    def voxel_order(self):
        """ Getter for the reference voxel order """
        return self._voxel_order

    @property
    def origin(self):
        """ Getter for origin standard """
        return self._origin

    @property
    def streamlines(self):
        """ Partially safe getter for streamlines """
        return self._tractogram.streamlines

    def get_streamlines_copy(self):
        """ Safe getter for streamlines (for slicing) """
        return self._tractogram.streamlines.copy()

    @streamlines.setter
    def streamlines(self, streamlines):
        """ Modify streamlines. Creating a new object would be less risky.

        Parameters
        ----------
        streamlines : list or ArraySequence (list and deepcopy recommanded)
            Streamlines of the tractogram
        """
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram._streamlines = Streamlines(streamlines)
        self.data_per_point = self.data_per_point
        self.data_per_streamline = self.data_per_streamline
        logger.warning('Streamlines has been modified.')

    @property
    def data_per_point(self):
        """ Getter for data_per_point """
        return self._tractogram.data_per_point

    @data_per_point.setter
    def data_per_point(self, data):
        """ Modify point data . Creating a new object would be less risky.

        Parameters
        ----------
        data : dict
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        """
        self._tractogram.data_per_point = data
        logger.warning('Data_per_point has been modified.')

    @property
    def data_per_streamline(self):
        """ Getter for data_per_streamline """
        return self._tractogram.data_per_streamline

    @data_per_streamline.setter
    def data_per_streamline(self, data):
        """ Modify point data . Creating a new object would be less risky.

        Parameters
        ----------
        data : dict
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
        """
        self._tractogram.data_per_streamline = data
        logger.warning('Data_per_streamline has been modified.')

    def get_data_per_point_keys(self):
        """ Return a list of the data_per_point attribute names """
        return list(self.data_per_point.keys())

    def get_data_per_streamline_keys(self):
        """ Return a list of the data_per_streamline attribute names """
        return list(self.data_per_streamline.keys())

    def to_vox(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOXMM:
            self._voxmm_to_vox()
        elif self._space == Space.RASMM:
            self._rasmm_to_vox()

    def to_voxmm(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOX:
            self._vox_to_voxmm()
        elif self._space == Space.RASMM:
            self._rasmm_to_voxmm()

    def to_rasmm(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOX:
            self._vox_to_rasmm()
        elif self._space == Space.VOXMM:
            self._voxmm_to_rasmm()

    def to_space(self, target_space):
        """ Safe function to transform streamlines to a particular space using
        an enum and update state """
        if target_space == Space.VOX:
            self.to_vox()
        elif target_space == Space.VOXMM:
            self.to_voxmm()
        elif target_space == Space.RASMM:
            self.to_rasmm()
        else:
            logger.error('Unsupported target space, please use Enum in '
                         'dipy.io.stateful_tractogram.')

    def to_origin(self, target_origin):
        """ Safe function moving the streamlines to the origin convention
        named by an Origin enum member: Origin.NIFTI puts the origin at the
        center of the voxel, Origin.TRACKVIS at its corner. Any other value
        is reported with logger.error and ignored. """
        if target_origin == Origin.NIFTI:
            return self.to_center()
        if target_origin == Origin.TRACKVIS:
            return self.to_corner()

        logger.error('Unsupported origin standard, please use Enum in '
                     'dipy.io.stateful_tractogram.')

    def to_center(self):
        """ Safe function making the voxel center the origin; only shifts
        when the current origin is the TrackVis (corner) convention """
        if self._origin != Origin.TRACKVIS:
            return
        self._shift_voxel_origin()

    def to_corner(self):
        """ Safe function making the voxel corner the origin; only shifts
        when the current origin is the NIFTI (center) convention """
        if self._origin != Origin.NIFTI:
            return
        self._shift_voxel_origin()

    def compute_bounding_box(self):
        """ Compute the axis-aligned bounding box of the streamlines, in
        whatever space/origin they currently are

        Returns
        -------
        output : ndarray
            (8, 3) array with every combination of per-axis min/max, i.e. the
            8 corners of the XYZ aligned box; all zeros if there are no points
        """
        points = self._tractogram.streamlines._data
        if points.size == 0:
            return np.zeros((8, 3))

        lows = np.min(points, axis=0)
        highs = np.max(points, axis=0)
        corners = [corner for corner in product(*zip(lows, highs))]
        return np.asarray(corners)

    def is_bbox_in_vox_valid(self):
        """ Check that the streamlines' bounding box fits in the reference
        volume, evaluated in voxel space with a corner origin.

        Any negative coordinate, or any coordinate above the matching volume
        dimension, makes the box invalid. The space and origin in use before
        the call are restored before returning.

        Returns
        -------
        output : bool
            True when there are no streamlines or every bounding-box corner
            lies within the volume of the reference, False otherwise.
        """
        if not self.streamlines:
            return True

        saved_space = deepcopy(self.space)
        saved_origin = deepcopy(self.origin)

        # Due to rotation, the equivalent of an OBB must be done: move to
        # voxel space first, then take the axis-aligned box there.
        self.to_vox()
        self.to_corner()
        corners = deepcopy(self.compute_bounding_box())

        below_zero = np.any(corners < 0)
        if below_zero:
            logger.error('Voxel space values lower than 0.0.')
            logger.debug(corners)

        above_dims = any(np.any(corners[:, axis] > self._dimensions[axis])
                         for axis in range(3))
        if above_dims:
            logger.error('Voxel space values higher than dimensions.')
            logger.debug(corners)

        self.to_space(saved_space)
        self.to_origin(saved_origin)

        return not (below_zero or above_dims)

    def remove_invalid_streamlines(self, epsilon=1e-3):
        """ Remove streamlines with invalid coordinates from the object.
        Will also remove the data_per_point and data_per_streamline.
        Coordinates are evaluated in voxel space with a corner origin; a
        point is invalid when any of its X,Y,Z values is below ``epsilon``
        or above the reference dimensions minus ``epsilon``. A streamline
        with at least one invalid point is removed. The space and origin in
        use before the call are restored afterwards.

        Parameters
        ----------
        epsilon : float (optional)
            Epsilon value for the bounding box verification.
            Default is 1e-3.

        Returns
        -------
        output : tuple
            Tuple of two list, indices_to_remove, indices_to_keep.
            Both lists are empty when there are no streamlines.
        """
        if not self.streamlines:
            # Keep the documented return type so callers can always unpack.
            return [], []

        old_space = deepcopy(self.space)
        old_origin = deepcopy(self.origin)

        self.to_vox()
        self.to_corner()

        points = self._tractogram.streamlines._data
        min_condition = np.min(points, axis=1) < epsilon
        max_condition = np.any(
            points > np.asarray(self._dimensions) - epsilon, axis=1)
        invalid_points = np.flatnonzero(np.logical_or(min_condition,
                                                      max_condition))

        # Owner of a point = last streamline whose offset is <= the point
        # index; same as bisect_right(offsets, i) - 1, but vectorized.
        offsets = np.asarray(self._tractogram.streamlines._offsets)
        owners = np.searchsorted(offsets, invalid_points, side='right') - 1
        indices_to_remove = np.unique(owners).tolist()

        indices_to_keep = list(
            np.setdiff1d(np.arange(len(self._tractogram)),
                         np.array(indices_to_remove, dtype=int)).astype(int))

        tmp_streamlines = self.streamlines[indices_to_keep]
        tmp_dpp = self._tractogram.data_per_point[indices_to_keep]
        tmp_dps = self._tractogram.data_per_streamline[indices_to_keep]

        self._tractogram = Tractogram(tmp_streamlines.copy(),
                                      data_per_point=tmp_dpp,
                                      data_per_streamline=tmp_dps,
                                      affine_to_rasmm=np.eye(4))

        self.to_space(old_space)
        self.to_origin(old_origin)

        return indices_to_remove, indices_to_keep

    def _get_streamline_count(self):
        """ Safe getter for the number of streamlines """
        return len(self._tractogram)

    def _get_point_count(self):
        """ Safe getter for the number of streamlines """
        return self._tractogram.streamlines.total_nb_rows

    def _vox_to_voxmm(self):
        """ Unsafe function to transform streamlines from VOX to VOXMM.

        Multiplies the raw points by the voxel sizes. The space flag is
        updated even when there are no points, so the recorded state always
        matches the requested conversion.
        """
        if self._space != Space.VOX:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.streamlines._data *= np.asarray(
                self._voxel_sizes)
        self._space = Space.VOXMM
        logger.debug('Moved streamlines from vox to voxmm.')

    def _voxmm_to_vox(self):
        """ Unsafe function to transform streamlines from VOXMM to VOX.

        Divides the raw points by the voxel sizes. The space flag is updated
        even when there are no points, so the recorded state always matches
        the requested conversion.
        """
        if self._space != Space.VOXMM:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.streamlines._data /= np.asarray(
                self._voxel_sizes)
        self._space = Space.VOX
        logger.debug('Moved streamlines from voxmm to vox.')

    def _vox_to_rasmm(self):
        """ Unsafe function to transform streamlines from VOX to RASMM.

        Applies the reference affine to the points. The space flag is
        updated even when there are no points, so the recorded state always
        matches the requested conversion.
        """
        if self._space != Space.VOX:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.apply_affine(self._affine)
        self._space = Space.RASMM
        logger.debug('Moved streamlines from vox to rasmm.')

    def _rasmm_to_vox(self):
        """ Unsafe function to transform streamlines from RASMM to VOX.

        Applies the inverse of the reference affine to the points. The space
        flag is updated even when there are no points, so the recorded state
        always matches the requested conversion.
        """
        if self._space != Space.RASMM:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.apply_affine(self._inv_affine)
        self._space = Space.VOX
        logger.debug('Moved streamlines from rasmm to vox.')

    def _voxmm_to_rasmm(self):
        """ Unsafe function to transform streamlines from VOXMM to RASMM.

        Divides the raw points by the voxel sizes (back to VOX) and then
        applies the reference affine. The space flag is updated even when
        there are no points, so the recorded state always matches the
        requested conversion.
        """
        if self._space != Space.VOXMM:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.streamlines._data /= np.asarray(
                self._voxel_sizes)
            self._tractogram.apply_affine(self._affine)
        self._space = Space.RASMM
        logger.debug('Moved streamlines from voxmm to rasmm.')

    def _rasmm_to_voxmm(self):
        """ Unsafe function to transform streamlines from RASMM to VOXMM.

        Applies the inverse reference affine (to VOX) and then multiplies the
        raw points by the voxel sizes. The space flag is updated even when
        there are no points, so the recorded state always matches the
        requested conversion.
        """
        if self._space != Space.RASMM:
            logger.warning('Wrong initial space for this function.')
            return

        if self._tractogram.streamlines._data.size > 0:
            self._tractogram.apply_affine(self._inv_affine)
            self._tractogram.streamlines._data *= np.asarray(
                self._voxel_sizes)
        self._space = Space.VOXMM
        logger.debug('Moved streamlines from rasmm to voxmm.')

    def _shift_voxel_origin(self):
        """ Unsafe function to switch the origin from center to corner
        and vice versa.

        Moves every point by half a voxel expressed in the current space
        (scaled by the voxel sizes in VOXMM, by the linear part of the affine
        in RASMM): +0.5 voxel when going from the NIFTI (center) to the
        TrackVis (corner) origin, -0.5 voxel the other way. The origin flag
        is toggled even when there are no streamlines, so the recorded state
        always matches the requested conversion.
        """
        if self.streamlines:
            shift = np.asarray([0.5, 0.5, 0.5])
            if self._space == Space.VOXMM:
                shift = shift * self._voxel_sizes
            elif self._space == Space.RASMM:
                # A shift is a displacement: use only the linear part of the
                # affine, never its translation.
                tmp_affine = np.eye(4)
                tmp_affine[0:3, 0:3] = self._affine[0:3, 0:3]
                shift = apply_affine(tmp_affine, shift)
            if self._origin == Origin.TRACKVIS:
                shift *= -1

            self._tractogram.streamlines._data += shift

        if self._origin == Origin.NIFTI:
            logger.debug('Origin moved to the corner of voxel.')
            self._origin = Origin.TRACKVIS
        else:
            logger.debug('Origin moved to the center of voxel.')
            self._origin = Origin.NIFTI
Exemplo n.º 10
0
    def __init__(self,
                 streamlines,
                 reference,
                 space,
                 shifted_origin=False,
                 data_per_point=None,
                 data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            Typically after tracking the space is VOX, after nibabel loading
            the space is RASMM
        shifted_origin : bool
            Information on the position of the origin,
            False is NIFTI standard, default (center of the voxel)
            True is TrackVis standard (corner of the voxel)
        data_per_point : dict
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If the reference cannot be interpreted, or shifted_origin is not
            a boolean.
        ValueError
            If space is not a member of the Space enum.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        # Copy so in-place coordinate changes never alias the caller's data.
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        space_attributes = get_reference_info(reference)
        if space_attributes is None:
            raise TypeError('Reference MUST be one of the following:\n' +
                            'Nifti or Trk filename, Nifti1Image or TrkFile, ' +
                            'Nifti1Header or trk.header (dict)')

        (self._affine, self._dimensions, self._voxel_sizes,
         self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance instead of `in Space`: membership tests with non-Enum
        # values raise TypeError on Python < 3.12 instead of returning False.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX')
        self._space = space

        if not isinstance(shifted_origin, bool):
            raise TypeError('shifted_origin MUST be a boolean')
        self._shifted_origin = shifted_origin
        logger.debug(self)
Exemplo n.º 11
0
def lossy_compression_of_tractogram(tractogramfile, outdir, rate=0.392,
                                    search_optimal_rate=False,
                                    weightsfile=None, weights_thr=0.,
                                    max_search_dist=2.2, verbose=0):
    """ Reduce the number of points of the track by keeping
    intact the start and endpoints of the track and trying to remove
    as many points as possible without distorting much the shape of
    the track, ie. more points in curvy regions and less points in less curvy
    regions.

    Parameters
    ----------
    tractogramfile: str
        the path to the tractogram.
    outdir: str
        the destination folder.
    rate: float, default 0.392
        the compression rate, ie. smoothing parameter (<0.392 smoother,
        >0.392 rougher).
    search_optimal_rate: bool, default False
        determine the optimal compression rate.
    weightsfile: str, default None
        use these weights to remove unsignificant streamlines.
    weights_thr: float, default 0.
        the threshold used to identify unsignificant streamlines.
    max_search_dist: float, default 2.2
        the maximum distance between the initial and downsampled streamlines
        allowed during the best rate search.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    compressed_tractogramfile: str
        the compressed tractogram.
    nb_points_file: str
        the compression result compared to the original sampling.
    """
    # Load the tractogram: nibabel returns the streamlines in RAS+ mm.
    trk = nibabel.streamlines.load(tractogramfile)
    if verbose > 0:
        print("[info] Number of tracks: {0}".format(len(trk.streamlines)))

    # Keep only significant streamlines
    tracks = trk.streamlines
    if weightsfile is not None:
        weights = numpy.loadtxt(weightsfile)
        keep_indices = numpy.where(weights > weights_thr)[0]
        # Select streamline by streamline: numpy.array() on ragged
        # streamlines may fail or build an unintended 3-D array.
        tracks = [tracks[idx] for idx in keep_indices]
        if verbose > 0:
            print("[info] Number of significant tracks: {0}".format(
                len(tracks)))

    # Compress tractogram
    # > dynamic compression rate
    if search_optimal_rate:
        rate = "dynamic"
        ref_lengths = numpy.asarray(list(length(tracks)))
        rates = numpy.linspace(1, 0, 21)
        opt_lengths = numpy.zeros((len(rates), len(tracks)))
        for idx, optrate in enumerate(rates):
            if verbose > 0:
                print("[info] Grid search at rate '{0}'.".format(optrate))
            decimated_tracks = [
                approx_polygon_track(t, optrate) for t in tracks]
            opt_lengths[idx] = list(length(decimated_tracks))
            opt_lengths[idx] -= ref_lengths
        opt_lengths = numpy.abs(opt_lengths)
        if verbose > 2:
            print("[debug] Optimal lengths: {0}".format(opt_lengths))
        # Discard rates whose length change exceeds max_search_dist, then
        # keep, per streamline, the rate with the largest remaining change.
        opt_lengths[opt_lengths > max_search_dist] = 0
        opt_rate_indices = numpy.argmax(opt_lengths, axis=0)
        if verbose > 2:
            print("[debug] Optimal rate indices: {0}".format(opt_rate_indices))
        tracks = [approx_polygon_track(t, rates[i])
                  for t, i in zip(tracks, opt_rate_indices)]
    # > static compression rate
    else:
        tracks = [approx_polygon_track(t, rate) for t in tracks]
    compressed_tractogramfile = os.path.join(
        outdir, "compressed_tractogram.trk")
    # The streamlines are already in RAS+ mm, so the affine to RAS+ mm is
    # the identity; trk.affine (voxel -> RAS+ mm) would be applied a second
    # time when saving.
    compressed_trk = Tractogram(
        streamlines=tracks,
        affine_to_rasmm=numpy.eye(4))
    nibabel.streamlines.save(compressed_trk, compressed_tractogramfile)

    # Summary graph
    n_pts_initial = [len(t) for t in trk.streamlines]
    n_pts_compressed = [len(t) for t in compressed_trk.streamlines]
    nb_points_file = os.path.join(outdir, "nb_points.png")
    fig, ax = plt.subplots(1)
    ax.hist(n_pts_initial, color="r", histtype="step", label="initial")
    ax.hist(n_pts_compressed, color="b", histtype="step",
            label="compressed ({0})".format(rate))
    ax.set_xlabel("Number of points")
    ax.set_ylabel("Count")
    ax.legend()
    fig.savefig(nb_points_file)
    # Release the figure: repeated calls would otherwise accumulate them.
    plt.close(fig)

    return compressed_tractogramfile, nb_points_file
Exemplo n.º 12
0
def mark(config, gpu_queue=None):
    """ Score every point of a tractogram with a trained model and save the
    result as ``marked.trk`` in a new timestamped folder next to the DWI.

    Parameters
    ----------
    config : dict
        Must provide "dwi_path", "trk_path", "model_name" and "model_path".
        "out_dir" is set to the output folder, then the config is saved with
        ``configs.save``.
    gpu_queue : queue-like, optional
        When given, a GPU id is taken from it and put back once prediction
        is over (even if it failed); otherwise ``maybe_get_a_gpu`` is used.
    """
    gpu_idx = None
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        # os.environ values must be strings.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)
    except Exception as e:
        # Best effort: keep going without pinning a device.
        print(str(e))

    try:
        trk_file, tractogram, fiber_lengths, outputs = _mark_predict(config)
    finally:
        # Give the GPU back even on failure so queued workers don't starve.
        if gpu_queue is not None and gpu_idx is not None:
            gpu_queue.put(gpu_idx)

    _mark_save(config, trk_file, tractogram, fiber_lengths, outputs)


def _mark_predict(config):
    """ Run the model step by step along every fiber.

    Returns the loaded trk file, the (possibly resampled) tractogram, the
    per-fiber point counts and an (n_fibers, max_length, 4) array whose
    channels are kappa, log1p(kappa), log_prob and log_prob_map.
    """
    print("Loading DWI data ...")

    dwi_img = nib.funcs.as_closest_canonical(nib.load(config["dwi_path"]))
    dwi_affi = np.linalg.inv(dwi_img.affine)
    # get_data() is deprecated in nibabel; dataobj keeps the on-disk dtype.
    dwi = np.asanyarray(dwi_img.dataobj)

    def xyz2ijk(coords, snap=False):
        """ Map (n, 3) world coordinates to voxel indices (n, 4 homogeneous). """
        ijk = (coords.T).copy()
        ijk = np.vstack([ijk, np.ones([1, ijk.shape[1]])])
        dwi_affi.dot(ijk, out=ijk)
        if snap:
            return (np.round(ijk, out=ijk).astype(int, copy=False).T)[:, :4]
        return (ijk.T)[:, :4]

    # ==========================================================================

    print("Loading fibers ...")

    trk_file = nib.streamlines.load(config["trk_path"])
    tractogram = trk_file.tractogram

    if "t" in tractogram.data_per_point:
        print("Fibers are already resampled")
    else:
        print("Fibers are not resampled. Resampling now ...")
        tractogram = maybe_add_tangent(config["trk_path"],
                                       min_length=30,
                                       max_length=200)
    tangents = tractogram.data_per_point["t"]

    n_fibers = len(tractogram)
    fiber_lengths = np.array([len(t.streamline) for t in tractogram])
    max_length = fiber_lengths.max()

    # ==========================================================================

    print("Loading model ...")
    model_class = MODELS[config['model_name']]

    if hasattr(model_class, "custom_objects"):
        model = load_model(config["model_path"],
                           custom_objects=model_class.custom_objects,
                           compile=False)
    else:
        model = load_model(config["model_path"], compile=False)

    block_size = get_blocksize(model, dwi.shape[-1])
    half = block_size // 2

    # Per-fiber normalized DWI block, plus its norm in the last column.
    d = np.zeros([n_fibers, dwi.shape[-1] * block_size**3 + 1])

    # Tangents are the directions fed to the model; positions locate the
    # DWI block to sample. Streamlines are assumed to be in RAS+ mm, like
    # what nibabel loads (dwi_affi maps RAS+ mm to voxels).
    inputs = np.zeros([n_fibers, max_length, 3])
    positions = np.zeros([n_fibers, max_length, 3])

    print("Writing to input array ...")

    for i, (fiber_xyz, fiber_t) in enumerate(zip(tractogram.streamlines,
                                                 tangents)):
        inputs[i, :fiber_lengths[i], :] = fiber_t
        positions[i, :fiber_lengths[i], :] = fiber_xyz

    outputs = np.zeros([n_fibers, max_length, 4])
    chunk = 2**15  # 32768
    n_chunks = int(np.ceil(n_fibers / chunk))

    print("Starting iteration ...")

    for step in range(max_length):
        t0 = time()

        ijk = xyz2ijk(positions[:, step, :], snap=True)

        for ii, idx in enumerate(ijk):
            try:
                d[ii, :-1] = dwi[idx[0] - half:idx[0] + half + 1,
                                 idx[1] - half:idx[1] + half + 1,
                                 idx[2] - half:idx[2] + half + 1,
                                 :].flatten()  # returns copy
            except (IndexError, ValueError):
                # Block (partly) outside the volume: use empty features
                # rather than the stale ones left from the previous step.
                d[ii, :-1] = 0

        d[:, -1] = np.linalg.norm(d[:, :-1], axis=1) + 10**-2
        d[:, :-1] /= d[:, -1].reshape(-1, 1)

        if step == 0:
            vin = -inputs[:, step + 1, :]
            vout = -inputs[:, step, :]
        else:
            vin = inputs[:, step - 1, :]
            vout = inputs[:, step, :]

        model_inputs = np.hstack([vin, d])
        for c in range(n_chunks):
            rows = slice(c * chunk, (c + 1) * chunk)

            fvm_pred, kappa_pred = model(model_inputs[rows])

            log1p_kappa_pred = np.log1p(kappa_pred)
            log_prob_pred = fvm_pred.log_prob(vout[rows])
            log_prob_map_pred = fvm_pred._log_normalization() + kappa_pred

            outputs[rows, step, 0] = kappa_pred
            outputs[rows, step, 1] = log1p_kappa_pred
            outputs[rows, step, 2] = log_prob_pred
            outputs[rows, step, 3] = log_prob_map_pred

        print("Step {:3d}/{:3d}, ETA: {:4.0f} min".format(
            step, max_length, (max_length - step) * (time() - t0) / 60),
              end="\r")

    return trk_file, tractogram, fiber_lengths, outputs


def _mark_save(config, trk_file, tractogram, fiber_lengths, outputs):
    """ Turn per-step model outputs into per-point data and write
    ``marked.trk``; records the output folder in config and saves it. """
    n_fibers = len(fiber_lengths)
    n_pts = fiber_lengths.sum()

    def per_point(channel):
        """ One (length, 1) column of the given output channel per fiber. """
        return [outputs[i, :fiber_lengths[i], channel].reshape(-1, 1)
                for i in range(n_fibers)]

    kappa = per_point(0)
    log_prob = per_point(2)
    log_prob_map = per_point(3)

    log_prob_sum = [np.ones_like(lp) * (lp.sum() / lpm.sum())
                    for lp, lpm in zip(log_prob, log_prob_map)]
    log_prob_ratio = [np.ones_like(lp) * (lp - lpm).mean()
                      for lp, lpm in zip(log_prob, log_prob_map)]

    # Carry over the other per-point fields of the tractogram that was
    # actually scored (the resampled one if resampling happened), so point
    # counts match n_pts.
    reserved = {"kappa", "log1p_kappa", "log_prob", "log_prob_map",
                "log_prob_sum", "log_prob_ratio"}
    other_data = {key: value
                  for key, value in tractogram.data_per_point.items()
                  if key not in reserved}

    data_per_point = PerArraySequenceDict(n_rows=n_pts,
                                          kappa=kappa,
                                          log_prob=log_prob,
                                          log_prob_sum=log_prob_sum,
                                          log_prob_ratio=log_prob_ratio,
                                          **other_data)
    marked = Tractogram(streamlines=tractogram.streamlines,
                        data_per_point=data_per_point,
                        affine_to_rasmm=np.eye(4))
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "marked_fibers", timestamp())
    os.makedirs(out_dir, exist_ok=True)

    marked_path = os.path.join(out_dir, "marked.trk")
    TrkFile(marked, trk_file.header).save(marked_path)

    config["out_dir"] = out_dir

    configs.save(config)
Exemplo n.º 13
0
def merge_trks(trk_dir, keep, weighted, out_dir):
    """ Merge every ``*.trk`` file of ``trk_dir`` into one tractogram,
    optionally subsample it, and save it.

    WARNING: Alignment between trk files is not checked, but assumed the same!

    Parameters
    ----------
    trk_dir : str
        Folder containing the .trk bundles to merge.
    keep : float
        Fraction of fibers to keep; values below 1 trigger a reproducible
        (seed 42) random subsampling without replacement.
    weighted : bool
        If True, each non-empty bundle gets the same total sampling
        probability; otherwise every fiber is equally likely. Also selects
        the "merged_W" output name prefix.
    out_dir : str or None
        Destination folder; defaults to
        ``os.path.join(os.path.dirname(trk_dir), "merged_tracts")``.

    Raises
    ------
    FileNotFoundError
        If ``trk_dir`` contains no .trk file.
    """
    # Sorted for a deterministic order: the reused header and the seeded
    # subsample both depend on it.
    trk_paths = sorted(glob.glob(os.path.join(trk_dir, "*.trk")))
    if not trk_paths:
        raise FileNotFoundError("No .trk file found in {}".format(trk_dir))

    bundles = []
    for i, trk_path in enumerate(trk_paths):
        print("Loading {:.<20}".format(os.path.basename(trk_path)), end="\r")
        trk_file = nib.streamlines.load(trk_path)
        bundles.append(trk_file.tractogram)
        if i == 0:
            header = trk_file.header

    bundle_sizes = [len(b.streamlines) for b in bundles]
    n_fibers = sum(bundle_sizes)
    n_bundles = len(bundles)

    print("Loaded {} fibers from {} bundles.".format(n_fibers, n_bundles))

    merged_bundles = bundles[0].copy()
    for b in bundles[1:]:
        merged_bundles.extend(b)

    if keep < 1 and n_fibers > 0:
        if weighted:
            # Each non-empty bundle receives total probability
            # 1 / n_nonempty, split evenly among its fibers; empty bundles
            # are skipped (they would divide by zero).
            n_nonempty = sum(1 for size in bundle_sizes if size > 0)
            p = np.zeros(n_fibers)
            offset = 0
            for size in bundle_sizes:
                if size:
                    p[offset:offset + size] = 1 / (size * n_nonempty)
                offset += size
        else:
            p = np.ones(n_fibers) / n_fibers

        keep_n = int(keep * n_fibers)
        print("Subsampling {} fibers".format(keep_n))

        # Local generator: reproducible without reseeding the global RNG.
        rng = np.random.RandomState(42)
        # Sample fiber indices: choice() needs a 1-D array, which the ragged
        # streamlines are not.
        subsample = rng.choice(n_fibers, size=keep_n, replace=False, p=p)

        tractogram = Tractogram(
                streamlines=merged_bundles.streamlines[subsample],
                affine_to_rasmm=np.eye(4)
            )
    else:
        tractogram = merged_bundles

    if out_dir is None:
        out_dir = os.path.dirname(trk_dir)
        out_dir = os.path.join(out_dir, "merged_tracts")

    os.makedirs(out_dir, exist_ok=True)

    prefix = "merged_W" if weighted else "merged_"
    save_path = os.path.join(
        out_dir, "{}{:04d}.trk".format(prefix, int(1000 * keep)))

    print("Saving {}".format(save_path))

    TrkFile(tractogram, header).save(save_path)
def main():
    """ Split a tractogram into streamlines inside and outside the grid of a
    reference image, and save both parts as .tck files.

    Streamlines are moved to the voxel space of ``args.signal``; a streamline
    is kept only if all of its points lie in [-0.5, dim - 0.5) on every axis.
    Outputs are ``<prefix>_filtered.tck`` and ``<prefix>_removed.tck``, the
    prefix defaulting to ``args.filename`` without its extension.
    """
    parser = build_argparser()
    args = parser.parse_args()

    signal = nib.load(args.signal)
    # Only the grid size is needed: read it from the header instead of
    # loading the whole image (get_data() is also deprecated in nibabel).
    volume_shape = np.asarray(signal.shape[:3])

    # Compute matrix that brings streamlines back to diffusion voxel space.
    rasmm2vox_affine = np.linalg.inv(signal.affine)

    # Retrieve data.
    with Timer("Retrieving data"):
        print("Loading {}".format(args.filename))

        # Load streamlines (already in RASmm space) and move them to voxels.
        tfile = nib.streamlines.load(args.filename)
        tfile.tractogram.apply_affine(rasmm2vox_affine)

        # affine_to_rasmm lets nibabel bring them back to RASmm on save.
        tractogram = Tractogram(streamlines=tfile.streamlines,
                                affine_to_rasmm=signal.affine)

    with Timer("Filtering streamlines"):

        # Valid voxel coordinates span [-0.5, dim - 0.5) on every axis.
        upper = volume_shape - 0.5

        mask = np.ones((len(tractogram), ), dtype=bool)

        for i, s in enumerate(tractogram.streamlines):
            # Identify streamlines out of bounds on any axis
            if np.any(np.logical_or(s < -0.5, s >= upper)):
                mask[i] = False

        tractogram_filtered = tractogram[mask]
        tractogram_removed = tractogram[np.logical_not(mask)]

        print("Kept {} streamlines and removed {} streamlines".format(
            len(tractogram_filtered), len(tractogram_removed)))

    with Timer("Saving filtered and removed streamlines"):
        base_filename = args.out_prefix
        if args.out_prefix is None:
            base_filename = os.path.splitext(args.filename)[0]

        tractogram_filtered_filename = "{}_filtered.tck".format(base_filename)
        tractogram_removed_filename = "{}_removed.tck".format(base_filename)

        # Save streamlines
        nib.streamlines.save(tractogram_filtered, tractogram_filtered_filename)
        nib.streamlines.save(tractogram_removed, tractogram_removed_filename)
Exemplo n.º 15
0
def run_rf_inference(config=None, gpu_queue=None):
    """ Track fibers from seeds with a pickled direction-classifier model,
    save them as a .trk file and optionally score them.

    Parameters
    ----------
    config : dict
        Needs "dwi_path", "model_dir", "term_path", "thresh", "prior_path",
        "seed_path", "max_steps", "step_size" and "score". "out_dir" is
        filled in and the config is written next to the result.
    gpu_queue : queue-like, optional
        When given, a GPU id is taken from it and put back once tracking is
        over (even if it failed); otherwise ``maybe_get_a_gpu`` is used.

    Returns
    -------
    Tractogram
        The complete fibers (both ends terminated), in RAS+ mm.
    """
    gpu_idx = None
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        # os.environ values must be strings.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)
    except Exception as e:
        # Best effort: keep going without pinning a device.
        print(str(e))

    try:
        fibers, seed_file = _rf_track(config)
    finally:
        # Give the GPU back so workers waiting on the queue are not starved.
        if gpu_queue is not None and gpu_idx is not None:
            gpu_queue.put(gpu_idx)

    # Save Result
    tractogram = Tractogram(streamlines=ArraySequence(fibers),
                            affine_to_rasmm=np.eye(4))

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "predicted_fibers", timestamp)

    configs.deep_update(config, {"out_dir": out_dir})

    os.makedirs(out_dir, exist_ok=True)

    fiber_path = os.path.join(out_dir, timestamp + ".trk")
    print("\nSaving {}".format(fiber_path))
    TrkFile(tractogram, seed_file.header).save(fiber_path)

    config_path = os.path.join(out_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    if config["score"]:
        score_on_tm(fiber_path)

    return tractogram


def _rf_track(config):
    """ Grow a fiber in both directions from every seed.

    Returns the list of complete fibers (each an (n, 3) array) and the
    loaded seed file, whose header is reused when saving.
    """
    print("Loading DWI...")

    dwi_img = nib.funcs.as_closest_canonical(nib.load(config['dwi_path']))
    dwi_affi = np.linalg.inv(dwi_img.affine)
    # get_data() is deprecated in nibabel; dataobj keeps the on-disk dtype.
    dwi = np.asanyarray(dwi_img.dataobj)

    def xyz2ijk(coords, snap=False):
        """ Map homogeneous (n, 4) world coordinates to voxel indices. """
        ijk = (coords.T).copy()
        dwi_affi.dot(ijk, out=ijk)
        if snap:
            return np.round(ijk, out=ijk).astype(int, copy=False).T
        return ijk.T

    # NOTE(review): pickle.load can execute arbitrary code; only load models
    # from trusted locations.
    with open(os.path.join(config['model_dir'], 'model'), 'rb') as f:
        model = pickle.load(f)

    train_config_file = os.path.join(config['model_dir'], 'config.yml')
    bvec_path = configs.load(train_config_file, 'bvecs')
    _, bvecs = read_bvals_bvecs(None, bvec_path)

    terminator = Terminator(config['term_path'], config['thresh'])

    prior = Prior(config['prior_path'])

    print("Initializing Fibers...")

    seed_file = nib.streamlines.load(config['seed_path'])
    xyz = seed_file.tractogram.streamlines.data
    n_seeds = 2 * len(xyz)
    xyz = np.vstack([xyz, xyz])  # Duplicate seeds for both directions
    xyz = np.hstack([xyz, np.ones([n_seeds, 1])])  # add affine dimension
    xyz = xyz.reshape(-1, 1, 4)  # (fiber, segment, coord)

    # Fiber id of each tracked end: both copies of a seed share one id.
    fiber_idx = np.tile(np.arange(n_seeds // 2, dtype="int32"), 2)
    fibers = [[] for _ in range(n_seeds // 2)]

    print("Start Iteration...")

    input_shape = model.n_features_
    block_size = int(np.cbrt(input_shape / dwi.shape[-1]))
    half = block_size // 2

    d = np.zeros([n_seeds, dwi.shape[-1] * block_size**3])
    dnorm = np.zeros([n_seeds, 1])
    vout = np.zeros([n_seeds, 3])
    chunk = 2**15  # 32768
    for i in range(config['max_steps']):
        t0 = time()

        # Get coords of latest segement for each fiber
        ijk = xyz2ijk(xyz[:, -1, :], snap=True)

        n_ongoing = len(ijk)

        for ii, idx in enumerate(ijk):
            d[ii] = dwi[idx[0] - half:idx[0] + half + 1,
                        idx[1] - half:idx[1] + half + 1,
                        idx[2] - half:idx[2] + half + 1,
                        :].flatten()  # returns copy
            dnorm[ii] = np.linalg.norm(d[ii])
            # An all-zero block would yield NaN features, which the model
            # (presumably a scikit-learn estimator) rejects; keep zeros.
            if dnorm[ii] > 0:
                d[ii] /= dnorm[ii]

        if i == 0:
            inputs = np.hstack(
                [prior(xyz[:, 0, :]), d[:n_ongoing], dnorm[:n_ongoing]])
        else:
            inputs = np.hstack(
                [vout[:n_ongoing], d[:n_ongoing], dnorm[:n_ongoing]])

        n_chunks = int(np.ceil(n_ongoing / chunk))
        for c in range(n_chunks):
            outputs = model.predict(inputs[c * chunk:(c + 1) * chunk])
            vout[c * chunk:(c + 1) * chunk] = bvecs[outputs, ...]

        rout = xyz[:, -1, :3] + config['step_size'] * vout
        rout = np.hstack([rout, np.ones((n_ongoing, 1))]).reshape(-1, 1, 4)

        xyz = np.concatenate([xyz, rout], axis=1)

        terminal_indices = terminator(xyz[:, -1, :])

        for idx in terminal_indices:
            gidx = fiber_idx[idx]
            # Other end not yet added
            if not fibers[gidx]:
                fibers[gidx].append(np.copy(xyz[idx, :, :3]))
            # Other end already added
            else:
                this_end = xyz[idx, :, :3]
                other_end = fibers[gidx][0]
                merged_fiber = np.vstack(
                    [np.flip(this_end[1:], axis=0),
                     other_end])  # stitch ends together
                fibers[gidx] = [merged_fiber]

        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)

        print(
            "Iter {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds with"
            " {:6.0f} steps/sec".format(
                (i + 1), config['max_steps'], n_seeds - n_ongoing, n_seeds,
                100 * (1 - n_ongoing / n_seeds), n_ongoing / (time() - t0)),
            end="\r")

        # Stop as soon as every end has terminated, instead of running one
        # more iteration on empty arrays.
        if len(fiber_idx) == 0:
            break

        gc.collect()

    # Exclude unfinished fibers: a fiber is complete once neither of its
    # ends is still being tracked; its list then holds the merged fiber.
    unfinished = set(fiber_idx.tolist())
    fibers = [fiber[0] for gidx, fiber in enumerate(fibers)
              if gidx not in unfinished]

    return fibers, seed_file
Exemplo n.º 16
0
class StatefulTractogram(object):
    """ Class for stateful representation of collections of streamlines
    Object designed to be identical no matter the file format
    (trk, tck, vtk, fib, dpy). Facilitate transformation between space and
    data manipulation for each streamline / point.

    Besides the coordinates, the object tracks two pieces of state: the
    current ``space`` (Space.VOX, Space.VOXMM or Space.RASMM) and whether the
    origin sits at the corner of the voxel (``shifted_origin`` True) or at
    its center (False). The ``to_*`` methods keep that state up to date,
    including when the tractogram holds no streamlines.
    """
    def __init__(self,
                 streamlines,
                 reference,
                 space,
                 shifted_origin=False,
                 data_per_point=None,
                 data_per_streamline=None):
        """ Create a strict, state-aware, robust tractogram

        Parameters
        ----------
        streamlines : list or ArraySequence
            Streamlines of the tractogram
        reference : Nifti or Trk filename, Nifti1Image or TrkFile,
            Nifti1Header, trk.header (dict) or another Stateful Tractogram
            Reference that provides the spatial attributes.
            Typically a nifti-related object from the native diffusion used for
            streamlines generation
        space : Enum (dipy.io.stateful_tractogram.Space)
            Current space in which the streamlines are (vox, voxmm or rasmm)
            Typically after tracking the space is VOX, after nibabel loading
            the space is RASMM
        shifted_origin : bool
            Position of the coordinate origin inside a voxel.
            False (default): origin at the center of the voxel (NIfTI
            convention).
            True: origin at the corner of the voxel (TrackVis convention).
        data_per_point : dict
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        data_per_streamline : dict
            Dictionary in which each key has X items
            X being the number of streamlines

        Raises
        ------
        TypeError
            If the reference cannot be interpreted, or shifted_origin is not
            a boolean.
        ValueError
            If space is not a member of the Space enum.

        Notes
        -----
        Very important to respect the convention, verify that streamlines
        match the reference and are effectively in the right space.

        Any change to the number of streamlines, data_per_point or
        data_per_streamline requires particular verification.

        In a case of manipulation not allowed by this object, use Nibabel
        directly and be careful.
        """
        if data_per_point is None:
            data_per_point = {}

        if data_per_streamline is None:
            data_per_streamline = {}

        # Copy so the in-place coordinate updates performed later
        # (e.g. `_data *= ...`) never touch the caller's object.
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram = Tractogram(streamlines,
                                      data_per_point=data_per_point,
                                      data_per_streamline=data_per_streamline)

        space_attributes = get_reference_info(reference)
        if space_attributes is None:
            raise TypeError('Reference MUST be one of the following:\n' +
                            'Nifti or Trk filename, Nifti1Image or TrkFile, ' +
                            'Nifti1Header or trk.header (dict)')

        (self._affine, self._dimensions, self._voxel_sizes,
         self._voxel_order) = space_attributes
        self._inv_affine = np.linalg.inv(self._affine)

        # isinstance instead of `space not in Space`: before Python 3.12 a
        # containment test with a non-member raises TypeError (not the
        # documented ValueError), and from 3.12 raw values are accepted.
        if not isinstance(space, Space):
            raise ValueError('Space MUST be from Space enum, e.g Space.VOX')
        self._space = space

        if not isinstance(shifted_origin, bool):
            raise TypeError('shifted_origin MUST be a boolean')
        self._shifted_origin = shifted_origin
        logging.debug(self)

    def __str__(self):
        """ Generate the string for printing """
        text = 'Affine: \n{}'.format(
            np.array2string(self._affine,
                            formatter={'float_kind': lambda x: "%.6f" % x}))
        text += '\ndimensions: {}'.format(np.array2string(self._dimensions))
        text += '\nvoxel_sizes: {}'.format(
            np.array2string(self._voxel_sizes,
                            formatter={'float_kind': lambda x: "%.2f" % x}))
        text += '\nvoxel_order: {}'.format(self._voxel_order)

        text += '\nstreamline_count: {}'.format(self._get_streamline_count())
        text += '\npoint_count: {}'.format(self._get_point_count())
        text += '\ndata_per_streamline keys: {}'.format(
            self.data_per_streamline.keys())
        text += '\ndata_per_point keys: {}'.format(
            self.data_per_point.keys())

        return text

    def __len__(self):
        """ Define the length of the object (number of streamlines) """
        return self._get_streamline_count()

    @property
    def space_attributes(self):
        """ Getter for spatial attribute
        (affine, dimensions, voxel_sizes, voxel_order) """
        return self._affine, self._dimensions, self._voxel_sizes, \
            self._voxel_order

    @property
    def space(self):
        """ Getter for the current space """
        return self._space

    @property
    def affine(self):
        """ Getter for the reference affine """
        return self._affine

    @property
    def dimensions(self):
        """ Getter for the reference dimensions """
        return self._dimensions

    @property
    def voxel_sizes(self):
        """ Getter for the reference voxel sizes """
        return self._voxel_sizes

    @property
    def voxel_order(self):
        """ Getter for the reference voxel order """
        return self._voxel_order

    @property
    def shifted_origin(self):
        """ Getter for shift (True: origin at the voxel corner) """
        return self._shifted_origin

    @property
    def streamlines(self):
        """ Partially safe getter for streamlines (no copy is made) """
        return self._tractogram.streamlines

    @streamlines.setter
    def streamlines(self, streamlines):
        """ Modify streamlines. Creating a new object would be less risky.

        Parameters
        ----------
        streamlines : list or ArraySequence (list and deepcopy recommended)
            Streamlines of the tractogram
        """
        if isinstance(streamlines, Streamlines):
            streamlines = streamlines.copy()
        self._tractogram._streamlines = Streamlines(streamlines)
        # Re-assign the data so it is re-validated against the new streamlines.
        self.data_per_point = self.data_per_point
        self.data_per_streamline = self.data_per_streamline
        logging.warning('Streamlines has been modified')

    def get_streamlines_copy(self):
        """ Safe getter for streamlines (for slicing) """
        return self._tractogram.streamlines.copy()

    @property
    def data_per_point(self):
        """ Getter for data_per_point """
        return self._tractogram.data_per_point

    @data_per_point.setter
    def data_per_point(self, data):
        """ Modify point data. Creating a new object would be less risky.

        Parameters
        ----------
        data : dict
            Dictionary in which each key has X items, each items has Y_i items
            X being the number of streamlines
            Y_i being the number of points on streamlines #i
        """
        self._tractogram.data_per_point = data
        logging.warning('Data_per_point has been modified')

    @property
    def data_per_streamline(self):
        """ Getter for data_per_streamline """
        return self._tractogram.data_per_streamline

    @data_per_streamline.setter
    def data_per_streamline(self, data):
        """ Modify streamline data. Creating a new object would be less risky.

        Parameters
        ----------
        data : dict
            Dictionary in which each key has X items
            X being the number of streamlines
        """
        self._tractogram.data_per_streamline = data
        logging.warning('Data_per_streamline has been modified')

    def to_vox(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOXMM:
            self._voxmm_to_vox()
        elif self._space == Space.RASMM:
            self._rasmm_to_vox()

    def to_voxmm(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOX:
            self._vox_to_voxmm()
        elif self._space == Space.RASMM:
            self._rasmm_to_voxmm()

    def to_rasmm(self):
        """ Safe function to transform streamlines and update state """
        if self._space == Space.VOX:
            self._vox_to_rasmm()
        elif self._space == Space.VOXMM:
            self._voxmm_to_rasmm()

    def to_space(self, target_space):
        """ Safe function to transform streamlines to a particular space using
        an enum and update state. An unsupported target is logged, not
        raised. """
        if target_space == Space.VOX:
            self.to_vox()
        elif target_space == Space.VOXMM:
            self.to_voxmm()
        elif target_space == Space.RASMM:
            self.to_rasmm()
        else:
            logging.error('Unsupported target space, please use Enum in '
                          'dipy.io.stateful_tractogram')

    def to_center(self):
        """ Safe function to shift streamlines so the center of voxel is
        the origin """
        if self._shifted_origin:
            self._shift_voxel_origin()

    def to_corner(self):
        """ Safe function to shift streamlines so the corner of voxel is
        the origin """
        if not self._shifted_origin:
            self._shift_voxel_origin()

    def compute_bounding_box(self):
        """ Compute the bounding box of the streamlines in their current state

        Returns
        -------
        output : ndarray
            8 corners of the XYZ aligned box, all zeros if no streamlines
        """
        if self._tractogram.streamlines.data.size > 0:
            bbox_min = np.min(self._tractogram.streamlines.data, axis=0)
            bbox_max = np.max(self._tractogram.streamlines.data, axis=0)
            return np.asarray(list(product(*zip(bbox_min, bbox_max))))

        return np.zeros((8, 3))

    def is_bbox_in_vox_valid(self):
        """ Verify that the bounding box is valid in voxel space.
        Negative coordinates or coordinates above the volume dimensions
        are considered invalid in voxel space. The check is done in voxel
        space with the origin at the voxel corner; the original space and
        origin are restored afterwards, even if a transformation fails.

        Returns
        -------
        output : bool
            Are the streamlines within the volume of the associated reference
        """
        if not self.streamlines:
            return True

        old_space = self._space
        old_shift = self._shifted_origin

        # Due to rotation, an axis-aligned box in RASMM is not enough: the
        # box is computed in voxel space instead (equivalent of an OBB).
        try:
            self.to_vox()
            self.to_corner()
            # compute_bounding_box builds a fresh array, so it is unaffected
            # by the state restoration below.
            bbox_corners = self.compute_bounding_box()
        finally:
            self._restore_state(old_space, old_shift)

        is_valid = True
        if np.any(bbox_corners < 0):
            logging.error('Voxel space values lower than 0.0')
            logging.debug(bbox_corners)
            is_valid = False

        if np.any(bbox_corners > np.asarray(self._dimensions)[:3]):
            logging.error('Voxel space values higher than dimensions')
            logging.debug(bbox_corners)
            is_valid = False

        return is_valid

    def remove_invalid_streamlines(self):
        """ Remove streamlines with invalid coordinates from the object.
        Will also remove the data_per_point and data_per_streamline.
        Invalid coordinates are any X,Y,Z values above the reference
        dimensions or below zero, evaluated in voxel space with the origin
        at the voxel corner. The original space and origin are restored
        afterwards, even if a transformation fails.

        Returns
        -------
        output : tuple
            (indices_to_remove, indices_to_keep): indices_to_remove is a
            sorted list without duplicates, indices_to_keep an integer
            ndarray. Both are empty when there are no streamlines.
        """
        if not self.streamlines:
            return [], np.array([], dtype=int)

        old_space = self._space
        old_shift = self._shifted_origin

        try:
            self.to_vox()
            self.to_corner()

            points = self._tractogram.streamlines.data
            below_zero = np.any(points < 0.0, axis=1)
            above_dims = np.any(points > self._dimensions, axis=1)
            invalid_points = np.flatnonzero(
                np.logical_or(below_zero, above_dims))

            # Offsets are the sorted start index of each streamline, so the
            # owner of point p is the last offset <= p (bisect_right - 1).
            # A streamline may hold several invalid points: deduplicate.
            owners = np.searchsorted(self._tractogram.streamlines._offsets,
                                     invalid_points, side='right') - 1
            indices_to_remove = np.unique(owners).tolist()

            indices_to_keep = np.setdiff1d(
                np.arange(len(self._tractogram)),
                np.asarray(indices_to_remove, dtype=int)).astype(int)

            tmp_streamlines = self.streamlines[indices_to_keep]
            tmp_data_per_point = \
                self._tractogram.data_per_point[indices_to_keep]
            tmp_data_per_streamline = \
                self._tractogram.data_per_streamline[indices_to_keep]

            self._tractogram = Tractogram(
                tmp_streamlines.copy(),
                data_per_point=tmp_data_per_point,
                data_per_streamline=tmp_data_per_streamline,
                affine_to_rasmm=np.eye(4))
        finally:
            self._restore_state(old_space, old_shift)

        return indices_to_remove, indices_to_keep

    def _restore_state(self, space, shifted_origin):
        """ Move the streamlines back to a previously recorded space and
        origin position """
        self.to_space(space)
        if shifted_origin:
            self.to_corner()
        else:
            self.to_center()

    def _get_streamline_count(self):
        """ Safe getter for the number of streamlines """
        return len(self._tractogram)

    def _get_point_count(self):
        """ Safe getter for the total number of points """
        return self._tractogram.streamlines.total_nb_rows

    def _vox_to_voxmm(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.VOX:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.streamlines._data *= np.asarray(
                self._voxel_sizes)
        # Update the state even when empty so `space` stays truthful.
        self._space = Space.VOXMM
        logging.info('Moved streamlines from vox to voxmm')

    def _voxmm_to_vox(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.VOXMM:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.streamlines._data /= np.asarray(
                self._voxel_sizes)
        self._space = Space.VOX
        logging.info('Moved streamlines from voxmm to vox')

    def _vox_to_rasmm(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.VOX:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.apply_affine(self._affine)
        self._space = Space.RASMM
        logging.info('Moved streamlines from vox to rasmm')

    def _rasmm_to_vox(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.RASMM:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.apply_affine(self._inv_affine)
        self._space = Space.VOX
        logging.info('Moved streamlines from rasmm to vox')

    def _voxmm_to_rasmm(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.VOXMM:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.streamlines._data /= np.asarray(
                self._voxel_sizes)
            self._tractogram.apply_affine(self._affine)
        self._space = Space.RASMM
        logging.info('Moved streamlines from voxmm to rasmm')

    def _rasmm_to_voxmm(self):
        """ Unsafe function to transform streamlines """
        if self._space != Space.RASMM:
            logging.warning('Wrong initial space for this function')
            return
        if self._tractogram.streamlines.data.size > 0:
            self._tractogram.apply_affine(self._inv_affine)
            self._tractogram.streamlines._data *= np.asarray(
                self._voxel_sizes)
        self._space = Space.VOXMM
        logging.info('Moved streamlines from rasmm to voxmm')

    def _shift_voxel_origin(self):
        """ Unsafe function to switch the origin from center to corner
        and vice versa """
        if self.streamlines:
            # Half a voxel, expressed in the units of the current space.
            shift = np.asarray([0.5, 0.5, 0.5])
            if self._space == Space.VOXMM:
                shift = shift * self._voxel_sizes
            elif self._space == Space.RASMM:
                # Linear part of the affine only: a displacement must not
                # pick up the translation.
                tmp_affine = np.eye(4)
                tmp_affine[0:3, 0:3] = self._affine[0:3, 0:3]
                shift = apply_affine(tmp_affine, shift)
            if self._shifted_origin:
                shift *= -1

            self._tractogram.streamlines._data += shift

        # Toggle the state even without streamlines so it matches the
        # origin the caller asked for.
        if not self._shifted_origin:
            logging.info('Origin moved to the corner of voxel')
        else:
            logging.info('Origin moved to the center of voxel')

        self._shifted_origin = not self._shifted_origin
Exemplo n.º 17
0
def save_tractogram(sft, filename, bbox_valid_check=True):
    """ Save the stateful tractogram in any format (trk, tck, vtk, fib, dpy)

    For writing, the streamlines are temporarily moved to RASMM space with
    the origin at the center of the voxel. The original space and origin of
    ``sft`` are restored afterwards, including when writing raises.
    The extension check is case-insensitive. data_per_point and
    data_per_streamline are written for .trk output only.

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save
    filename : string
        Filename with valid extension
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.

    Returns
    -------
    output : bool
        True if the saving operation was successful

    Raises
    ------
    TypeError
        If the extension of ``filename`` is not a supported format.
    ValueError
        If ``bbox_valid_check`` is True and some coordinates are invalid in
        voxel space.
    """
    _, extension = os.path.splitext(filename)
    extension = extension.lower()
    if extension not in ('.trk', '.tck', '.vtk', '.fib', '.dpy'):
        raise TypeError('Output filename is not one of the supported format')

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'save a valid file if some coordinates are ' +
                         'invalid. Please use the function ' +
                         'remove_invalid_streamlines to discard invalid ' +
                         'streamlines or set bbox_valid_check to False')

    old_space = sft.space
    old_shift = sft.shifted_origin

    try:
        sft.to_rasmm()
        sft.to_center()

        timer = time.time()
        if extension in ('.trk', '.tck'):
            tractogram_type = detect_format(filename)
            header = create_tractogram_header(tractogram_type,
                                              *sft.space_attributes)
            new_tractogram = Tractogram(sft.streamlines,
                                        affine_to_rasmm=np.eye(4))

            # Point/streamline data is only written for the TRK format.
            if extension == '.trk':
                new_tractogram.data_per_point = sft.data_per_point
                new_tractogram.data_per_streamline = sft.data_per_streamline

            fileobj = tractogram_type(new_tractogram, header=header)
            nib.streamlines.save(fileobj, filename)

        elif extension in ('.vtk', '.fib'):
            save_vtk_streamlines(sft.streamlines, filename, binary=True)
        elif extension == '.dpy':
            dpy_obj = Dpy(filename, mode='w')
            try:
                dpy_obj.write_tracks(sft.streamlines)
            finally:
                dpy_obj.close()

        logging.debug('Save %s with %s streamlines in %s seconds', filename,
                      len(sft), round(time.time() - timer, 3))
    finally:
        # Put the caller's object back in the state it was given in.
        if old_space == Space.VOX:
            sft.to_vox()
        elif old_space == Space.VOXMM:
            sft.to_voxmm()

        if old_shift:
            sft.to_corner()

    return True
Exemplo n.º 18
0
def resample_tractogram(tractogram,
                        npts,
                        smoothing,
                        min_length=0,
                        max_length=1000):
    """Resample the streamlines of a tractogram with ``fiber_geometry``.

    A streamline is dropped when its polyline length lies outside
    ``[min_length, max_length]``, or when its resampled points end up more
    than 1.2 times further from their centroid than the original points
    (counted as a failed resampling).

    Parameters
    ----------
    tractogram : Tractogram
        Input streamlines. They are assumed to already be in RASMM space,
        since the output uses an identity ``affine_to_rasmm``.
    npts : int or "same"
        Forwarded to ``fiber_geometry``. With "same", every per-point data
        key of the input except "t" is copied to the output, restricted to
        the streamlines that were kept.
    smoothing
        Forwarded to ``fiber_geometry``.
    min_length, max_length : float
        Inclusive bounds on the summed segment lengths of a streamline, in
        the units of its coordinates.

    Returns
    -------
    Tractogram
        Resampled streamlines with the tangents in ``data_per_point["t"]``.
    """
    streamlines = tractogram.streamlines
    n_streamlines = len(streamlines)

    position = ArraySequence()
    tangent = ArraySequence()
    kept = []  # input indices of the streamlines present in `position`
    rows = 0

    def max_dist_from_mean(path):
        """Largest distance between a point of ``path`` and its centroid."""
        return np.linalg.norm(path - np.mean(path, axis=0, keepdims=True),
                              axis=1).max()

    n_fails = 0
    n_length = 0
    for i, f in enumerate(streamlines):

        flen = np.linalg.norm(f[1:] - f[:-1], axis=1).sum()
        if (flen < min_length) or (flen > max_length):
            n_length += 1
            continue

        # NOTE(review): cnt is presumably the number of resampled points;
        # it is summed into the row count of the per-point data.
        r, t, cnt = fiber_geometry(f, npts=npts, smoothing=smoothing)

        if max_dist_from_mean(r) > 1.2 * max_dist_from_mean(f):
            n_fails += 1
            continue

        position.append(r, cache_build=True)
        tangent.append(t, cache_build=True)
        kept.append(i)
        rows += cnt

        print("Finished {:3.0f}%".format(100 * (i + 1) / n_streamlines),
              end="\r")

    if n_fails > 0:
        print("Failed to resample {} out of {} ".format(
            n_fails, n_streamlines) + "fibers, they were not included.")

    if n_length > 0:
        print("{} out of {} ".format(n_length, n_streamlines) +
              "fibers excluded by length.")

    position.finalize_append()
    tangent.finalize_append()

    other_data = {}
    if npts == "same":
        kept_idx = np.asarray(kept, dtype=int)
        for key in list(tractogram.data_per_point.keys()):
            if key != "t":
                # Keep only the rows of the streamlines that survived the
                # filters, so the data stays aligned with `position`.
                other_data[key] = tractogram.data_per_point[key][kept_idx]

    data_per_point = PerArraySequenceDict(n_rows=rows, t=tangent, **other_data)

    return Tractogram(
        streamlines=position,
        data_per_point=data_per_point,
        affine_to_rasmm=np.eye(
            4)  # Fiber coordinates are already in rasmm space!
    )