Пример #1
0
def bundle_min_distance(t, static, moving):
    """Evaluate the MDF-based minimum-distance cost (MIN) for a transform.

    The moving streamlines are transformed with the affine built from ``t``
    and the MDF distance matrix against the static streamlines is reduced
    to a single scalar.

    Parameters
    ----------
    t : ndarray
        Vector of affine transformation parameters with size at least 6.
        If size is 6, t is interpreted as translation + rotation.
        If size is 7, t is interpreted as translation + rotation +
        isotropic scaling.
        If size is 12, t is interpreted as translation + rotation +
        scaling + shearing.

    static : list
        Static streamlines.

    moving : list
        Moving streamlines. A transformed copy is compared against
        ``static``.

    Returns
    -------
    cost : float
        ``0.25 * (a + b) ** 2`` where ``a`` is the sum of the axis-0 minima
        of the distance matrix divided by its number of columns and ``b``
        is the sum of the axis-1 minima divided by its number of rows.

    """
    affine = compose_matrix44(t)
    moved = transform_streamlines(moving, affine)
    dist = distance_matrix_mdf(static, moved)

    n_rows, n_cols = dist.shape
    # Average of the per-column minima and of the per-row minima.
    col_term = np.sum(np.min(dist, axis=0)) / float(n_cols)
    row_term = np.sum(np.min(dist, axis=1)) / float(n_rows)
    return 0.25 * (col_term + row_term) ** 2
Пример #2
0
def bundle_sum_distance(t, static, moving, num_threads=None):
    """Evaluate the MDF-based summed-distance cost (SUM) for a transform.

    The moving streamlines are transformed with the affine built from ``t``
    and every entry of the MDF distance matrix against the static
    streamlines is added up.

    Parameters
    ----------
    t : ndarray
        Vector of affine transformation parameters with size at least 6.
        If size is 6, t is interpreted as translation + rotation.
        If size is 7, t is interpreted as translation + rotation +
        isotropic scaling.
        If size is 12, t is interpreted as translation + rotation +
        scaling + shearing.

    static : list
        Static streamlines.

    moving : list
        Moving streamlines. These will be transformed to align with
        the static streamlines.

    num_threads : int, optional
        Accepted for interface compatibility; this implementation does not
        use it.

    Returns
    -------
    cost : float
        Sum of all entries of the MDF distance matrix.

    """
    moved = transform_streamlines(moving, compose_matrix44(t))
    dist = distance_matrix_mdf(static, moved)
    return np.sum(dist)
Пример #3
0
def test_cascade_of_optimizations():
    """Chain rigid, similarity and affine registrations and check the cost.

    Each stage starts from the matrix found by the previous one; the final
    cost of every stage must be strictly lower than the one before it.
    """
    source_bundle = two_cingulum_bundles()[0]

    static = set_number_of_points(source_bundle, 20)

    # 12-parameter vector used to build the synthetic "moving" bundle.
    true_params = np.array([10, 4, 3, 0, 20, 10, 1.5, 1.5, 1.5, 0., 0.2, 0])

    moving = transform_streamlines(source_bundle,
                                   compose_matrix44(true_params))
    moving = set_number_of_points(moving, 20)

    print('first rigid')
    rigid_map = StreamlineLinearRegistration(x0=6).optimize(static, moving)

    print('then similarity')
    similarity_reg = StreamlineLinearRegistration(x0=7)
    similarity_map = similarity_reg.optimize(static, moving,
                                             rigid_map.matrix)

    print('then affine')
    affine_reg = StreamlineLinearRegistration(x0=12,
                                              options={'maxiter': 50})
    affine_map = affine_reg.optimize(static, moving, similarity_map.matrix)

    assert_(similarity_map.fopt < rigid_map.fopt)
    assert_(affine_map.fopt < similarity_map.fopt)
Пример #4
0
def test_rigid_real_bundles():
    """Recover a known rigid transform with both distance-matrix metrics.

    A centred fornix bundle is transformed with a known 6-parameter
    matrix and registered back using Powell's method, once per metric.
    An unknown optimisation method must raise ``ValueError``.
    """
    centered, _shift = center_streamlines(fornix_streamlines()[:20])

    known_matrix = compose_matrix44([0, 0, 20, 45., 0, 0])
    displaced = transform_streamlines(centered, known_matrix)

    for metric_cls in (BundleSumDistanceMatrixMetric,
                       BundleMinDistanceMatrixMetric):
        registration = StreamlineLinearRegistration(metric_cls(),
                                                    x0=np.zeros(6),
                                                    method='Powell')
        streamline_map = registration.optimize(centered, displaced)
        realigned = streamline_map.transform(displaced)

        evaluate_convergence(centered, realigned)

    assert_raises(ValueError, StreamlineLinearRegistration, method='Whatever')
Пример #5
0
def test_stream_rigid():
    """Default registration gives the same result with and without verbose.

    Also checks that ``transform`` on the returned map matches applying
    its ``matrix`` through ``transform_streamlines``.
    """
    static = fornix_streamlines()[:20]
    moving = fornix_streamlines()[20:40]
    # The centred result is not used below; the call is kept as in the
    # original test.
    _centered_static, _shift = center_streamlines(static)

    moving = transform_streamlines(moving,
                                   compose_matrix44([0, 0, 0, 0, 40, 0]))

    quiet_map = StreamlineLinearRegistration().optimize(static, moving)
    moved_quiet = transform_streamlines(moving, quiet_map.matrix)

    verbose_map = StreamlineLinearRegistration(verbose=True).optimize(
        static, moving)
    moved_verbose = transform_streamlines(moving, verbose_map.matrix)
    moved_via_method = verbose_map.transform(moving)

    assert_array_almost_equal(moved_quiet[0], moved_verbose[0], decimal=3)
    assert_array_almost_equal(moved_verbose[0], moved_via_method[0],
                              decimal=3)
Пример #6
0
def test_center_and_transform():
    """Centering and an equivalent explicit affine both map to the origin."""
    points = np.array([[1, 2, 3], [1, 2, 3.]])
    bundle = [points] * 10
    zeros = np.zeros((2, 3))

    centered, center = center_streamlines(bundle)
    assert_array_equal(centered[0], zeros)
    assert_array_equal(center, points[0])

    # Scale x by 2 and translate so that every point lands on the origin.
    mat = np.eye(4)
    mat[0, 0] = 2
    mat[:3, -1] = - np.array([2, 1, 1]) * center
    transformed = transform_streamlines(bundle, mat)
    assert_array_equal(transformed[0], zeros)
Пример #7
0
def test_deform_streamlines():
    """Deforming streamlines and undoing the displacement is a round trip.

    Uses the module-level ``streamlines`` fixture (defined outside this
    function).
    """
    # Random deformation field
    field = np.random.randn(200, 200, 200, 3)

    def _random_scaling():
        # 4x4 matrix with three random diagonal entries and a trailing 1.
        return np.diag([np.random.randn(1)[0] for _ in range(3)] + [1])

    # Specify stream2grid and grid2world
    stream2grid = _random_scaling()
    grid2world = _random_scaling()
    stream2world = np.dot(stream2grid, grid2world)

    # Deform streamlines (let two grid spaces be the same for simplicity)
    deformed = deform_streamlines(streamlines,
                                  field,
                                  stream2grid,
                                  grid2world,
                                  stream2grid,
                                  grid2world)

    # Interpolate displacements onto original streamlines
    in_grid = transform_streamlines(streamlines, stream2grid)
    displacements = values_from_volume(field, in_grid)

    # Put the deformed streamlines into world space
    deformed_world = transform_streamlines(deformed, stream2world)

    # Subtract the displacements in world space
    restored_world = list(np.subtract(deformed_world, displacements))

    # Bring the result back into the original streamline space
    restored = transform_streamlines(restored_world,
                                     np.linalg.inv(stream2world))

    # Only close, not equal, because of floating point imprecision
    for recovered, expected in zip(restored, streamlines):
        assert_allclose(expected, recovered, rtol=1e-10, atol=0)
Пример #8
0
def voxel2streamline(streamline, transformed=False, affine=None,
                     unique_idx=None):
    """
    Map voxels to streamlines and streamlines to voxels, for setting up
    the LiFE equations matrix.

    Parameters
    ----------
    streamline : list
        A collection of streamlines, each n by 3, with n being the number of
        nodes in the fiber.

    transformed : bool (optional)
        Whether the streamlines have already been transformed (in which case
        ``affine`` is ignored and no transformation is applied here).

    affine : 4 by 4 array (optional)
        Defines the spatial transformation from streamline to data.
        Default: np.eye(4)

    unique_idx : array (optional).
        The unique indices in the streamlines. When omitted, they are
        computed from all streamline coordinates cast with ``astype(int)``.

    Returns
    -------
    v2f, v2fn : tuple
        Whatever ``_voxel2streamline`` returns for the (transformed)
        streamlines and the unique indices. According to the original
        documentation, the first element answers "given a voxel (from the
        unique indices in this model), which fibers pass through it?" and
        the second answers "given a voxel, for each fiber, which nodes are
        in that voxel?".
    """
    if not transformed:
        mapping = np.eye(4) if affine is None else affine
        streamline = transform_streamlines(streamline, mapping)

    if unique_idx is None:
        coords = np.concatenate(streamline)
        unique_idx = unique_rows(coords.astype(int))

    return _voxel2streamline(streamline, unique_idx)
Пример #9
0
    def transform(self, moving):
        """ Apply this map's matrix to the moving streamlines.

        Parameters
        ----------
        moving : streamlines
            Streamlines to be brought into the space of the static ones.

        Returns
        -------
        moved : streamlines
            Result of ``transform_streamlines(moving, self.matrix)``.

        Notes
        -----
        This is a thin convenience wrapper; it only applies ``self.matrix``
        to the input streamlines.
        """
        affine = self.matrix
        return transform_streamlines(moving, affine)
Пример #10
0
def test_rigid_parallel_lines():
    """Recover a rigid transform of a simulated bundle with L-BFGS-B."""
    centered, _shift = center_streamlines(simulated_bundle())

    known_matrix = compose_matrix44([20, 0, 10, 0, 40, 0])
    displaced = transform_streamlines(centered, known_matrix)

    lbfgsb_options = {'maxcor': 100, 'ftol': 1e-9, 'gtol': 1e-16,
                      'eps': 1e-3}
    registration = StreamlineLinearRegistration(
        metric=BundleSumDistanceMatrixMetric(),
        x0=np.zeros(6),
        method='L-BFGS-B',
        bounds=None,
        options=lbfgsb_options)

    streamline_map = registration.optimize(centered, displaced)
    realigned = streamline_map.transform(displaced)
    evaluate_convergence(centered, realigned)
Пример #11
0
def test_affine_real_bundles():
    """Two-stage affine registration: bounded L-BFGS-B, then Powell.

    The second stage registers the output of the first stage against the
    original bundle; convergence is evaluated on the final result.
    """
    all_centered, _shift = center_streamlines(fornix_streamlines())
    static = all_centered[:20]

    gold_params = [0, 4, 2, 0, 10, 10, 1.2, 1.1, 1., 0., 0.2, 0.]
    moving = transform_streamlines(all_centered[:20],
                                   compose_matrix44(gold_params))

    start = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1., 0, 0, 0])

    # Bounds: 6 translation/rotation entries, 3 scalings, 3 shears.
    limit = 25
    param_bounds = ([(-limit, limit)] * 6 +
                    [(0.1, 1.5)] * 3 +
                    [(-1, 1)] * 3)

    lbfgsb_options = {'maxcor': 10, 'ftol': 1e-7, 'gtol': 1e-5,
                      'eps': 1e-8}

    metric = BundleMinDistanceMatrixMetric()

    first_stage = StreamlineLinearRegistration(metric=metric,
                                               x0=start,
                                               method='L-BFGS-B',
                                               bounds=param_bounds,
                                               verbose=True,
                                               options=lbfgsb_options)
    moved = first_stage.optimize(static, moving).transform(moving)

    second_stage = StreamlineLinearRegistration(metric=metric,
                                                x0=start,
                                                method='Powell',
                                                bounds=None,
                                                verbose=True,
                                                options=None)
    moved = second_stage.optimize(static, moved).transform(moved)

    evaluate_convergence(static, moved)
Пример #12
0
def test_rigid_partial_real_bundles():
    """Register two different fornix sub-bundles and check voxel overlap.

    Two disjoint sets of 20 streamlines are centred, the second is
    transformed with a known matrix and registered back to the first. Both
    bundles are then rasterised into 100x100x100 volumes (offset by 50
    voxels) and the fraction of the moved bundle's voxels that are also
    occupied by the static bundle must exceed 40%.
    """
    static = fornix_streamlines()[:20]
    moving = fornix_streamlines()[20:40]
    static_center, shift = center_streamlines(static)
    moving_center, shift2 = center_streamlines(moving)

    print(shift2)
    mat = compose_matrix(translate=np.array([0, 0, 0.]),
                         angles=np.deg2rad([40, 0, 0.]))
    moved = transform_streamlines(moving_center, mat)

    srr = StreamlineLinearRegistration()

    srm = srr.optimize(static_center, moved)
    print(srm.fopt)
    print(srm.iterations)
    print(srm.funcs)

    moving_back = srm.transform(moved)
    print(srm.matrix)

    static_center = set_number_of_points(static_center, 100)
    moving_center = set_number_of_points(moving_back, 100)

    # ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin ``int`` is the exact equivalent.
    offset = np.array([50, 50, 50])

    vol = np.zeros((100, 100, 100))
    spts = np.concatenate(static_center, axis=0)
    spts = np.round(spts).astype(int) + offset

    mpts = np.concatenate(moving_center, axis=0)
    mpts = np.round(mpts).astype(int) + offset

    for i, j, k in spts:
        vol[i, j, k] = 1

    vol2 = np.zeros((100, 100, 100))
    for i, j, k in mpts:
        vol2[i, j, k] = 1

    overlap = np.sum(np.logical_and(vol, vol2)) / float(np.sum(vol2))

    assert_equal(overlap * 100 > 40, True)
Пример #13
0
def test_similarity_real_bundles():
    """Recover a 7-parameter (similarity) transform with Powell's method."""
    all_centered, _shift = center_streamlines(fornix_streamlines())
    static = all_centered[:20]

    gold_params = [0, 0, 10, 0, 0, 0, 1.5]
    moving = transform_streamlines(all_centered[:20],
                                   compose_matrix44(gold_params))

    start = np.array([0, 0, 0, 0, 0, 0, 1], 'f8')

    registration = StreamlineLinearRegistration(
        metric=BundleMinDistanceMatrixMetric(),
        x0=start,
        method='Powell',
        bounds=None,
        verbose=False)

    streamline_map = registration.optimize(static, moving)
    realigned = streamline_map.transform(moving)
    evaluate_convergence(static, realigned)
Пример #14
0
def plot_bundles_with_metric(bundle_path,
                             endings_path,
                             brain_mask_path,
                             bundle,
                             metrics,
                             output_path,
                             tracking_format="trk_legacy",
                             show_color_bar=True):
    """Render a bundle coloured by a per-segment metric and save the image.

    The streamlines are loaded, mapped with the inverse affine of the
    endings image, oriented so that they start inside the (dilated)
    endings mask, resampled, assigned a segment index per point, coloured
    from a red palette according to ``metrics`` and rendered as
    streamtubes together with a smoothed contour of the brain mask.

    Parameters
    ----------
    bundle_path : str
        Path of the tractogram. Read with ``trackvis.read`` when
        ``tracking_format`` is ``"trk_legacy"``, otherwise with
        ``nib.streamlines.load``.
    endings_path : str
        Path of the image loaded with ``nib.load`` whose data is used to
        decide whether a streamline has to be flipped and whose affine is
        inverted to transform the streamlines.
    brain_mask_path : str
        Path of the image whose data is rendered as a contour.
    bundle : str
        Bundle identifier passed to
        ``dataset_specific_utils.get_optimal_orientation_for_bundle``.
    metrics : sequence of float
        Metric values along the bundle. The first and last values are
        duplicated before use (see the comment in the body).
    output_path : str
        Passed as ``out_path`` to ``window.record``.
    tracking_format : str, optional
        ``"trk_legacy"`` selects the trackvis reader; any other value
        selects ``nib.streamlines.load``.
    show_color_bar : bool, optional
        Whether a scalar bar is added to the scene.

    Raises
    ------
    ValueError
        If the orientation returned for ``bundle`` is not one of
        ``"sagittal"``, ``"coronal"`` or ``"axial"``.
    """
    import seaborn as sns  # import in function to avoid error if not installed (this is only needed in this function)
    from dipy.viz import actor, window
    from tractseg.libs import vtk_utils

    def _add_extra_point_to_last_streamline(sl):
        # Coloring breaks as soon as all streamlines have the same number of
        # points (reason unknown). Work around it by duplicating the last
        # point of the last streamline so that its length differs.
        sl[-1] = np.append(sl[-1], [sl[-1][-1]], axis=0)
        return sl

    # Settings
    NR_SEGMENTS = 100
    ANTI_INTERPOL_MULT = 1  # increase number of points to avoid interpolation to blur the colors
    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane
    # colors = np.array(sns.color_palette("coolwarm", NR_SEGMENTS))  # colormap blue to red (does not fit to colorbar)
    colors = np.array(sns.light_palette(
        "red", NR_SEGMENTS))  # colormap only red, which fits to color_bar
    img_size = (1000, 1000)

    # Tractometry skips first and last element. Therefore we only have 98 instead of 100 elements.
    # Here we duplicate the first and last element to get back to 100 elements
    metrics = list(metrics)
    metrics = np.array([metrics[0]] + metrics + [metrics[-1]])

    # Keep the original extrema for the colour bar before rescaling.
    metrics_max = metrics.max()
    metrics_min = metrics.min()
    if metrics_max == metrics_min:
        # Constant metric: map everything to the first palette entry.
        metrics = np.zeros(len(metrics))
    else:
        metrics = img_utils.scale_to_range(
            metrics,
            range=(0, 99))  # range needs to be same as segments in colormap

    orientation = dataset_specific_utils.get_optimal_orientation_for_bundle(
        bundle)

    # Load mask
    beginnings_img = nib.load(endings_path)
    beginnings = beginnings_img.get_data()
    # Single dilation pass of the endings mask.
    for i in range(1):
        beginnings = binary_dilation(beginnings)

    # Load trackings
    if tracking_format == "trk_legacy":
        streams, hdr = trackvis.read(bundle_path)
        streamlines = [s[0] for s in streams]
    else:
        sl_file = nib.streamlines.load(bundle_path)
        streamlines = sl_file.streamlines
    # Apply the inverse affine of the endings image to the streamlines.
    streamlines = list(
        transform_streamlines(streamlines,
                              np.linalg.inv(beginnings_img.affine)))

    # Reduce streamline count
    streamlines = streamlines[::2]

    # Reorder to make all streamlines have same start region
    streamlines = fiber_utils.add_to_each_streamline(streamlines, 0.5)
    streamlines_new = []
    for idx, sl in enumerate(streamlines):
        startpoint = sl[0]
        # Flip streamline if not in right order
        if beginnings[int(startpoint[0]),
                      int(startpoint[1]),
                      int(startpoint[2])] == 0:
            sl = sl[::-1, :]
        streamlines_new.append(sl)
    # Undo the 0.5 offset added before the mask lookup.
    streamlines = fiber_utils.add_to_each_streamline(streamlines_new, -0.5)

    if algorithm == "distance_map" or algorithm == "equal_dist":
        streamlines = fiber_utils.resample_fibers(
            streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    elif algorithm == "cutting_plane":
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines,
            max_nr_points=NR_SEGMENTS,
            ANTI_INTERPOL_MULT=ANTI_INTERPOL_MULT)

    # Cut start and end by percentage
    # streamlines = FiberUtils.resample_fibers(streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    # remove = int((NR_SEGMENTS * ANTI_INTERPOL_MULT) * 0.15)  # remove X% in beginning and end
    # streamlines = np.array(streamlines)[:, remove:-remove, :]
    # streamlines = list(streamlines)

    if algorithm == "equal_dist":
        # Every point simply gets its own position as segment index.
        segment_idxs = []
        for i in range(len(streamlines)):
            segment_idxs.append(list(range(NR_SEGMENTS * ANTI_INTERPOL_MULT)))
        segment_idxs = np.array(segment_idxs)

    elif algorithm == "distance_map":
        # Each point gets the index of its nearest centroid point.
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        _, segment_idxs = cKDTree(centroids.data, 1,
                                  copy_data=True).query(streamlines, k=1)

    elif algorithm == "cutting_plane":
        streamlines_resamp = fiber_utils.resample_fibers(
            streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroid = Streamlines(clusters.centroids)[0]
        # index of the middle cluster
        middle_idx = int(NR_SEGMENTS / 2) * ANTI_INTERPOL_MULT
        middle_point = centroid[middle_idx]
        segment_idxs = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)
        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        for idx, sl in enumerate(streamlines):
            sl_middle_pos = segment_idxs[idx]
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            base_idx = 1000  # use higher index to avoid negative numbers for area below middle
            r = range((base_idx - before_elems), (base_idx + after_elems))
            segment_idxs_eqlen.append(r)
        segment_idxs = segment_idxs_eqlen

    # Add extra point, otherwise the coloring bug described above is triggered
    streamlines = _add_extra_point_to_last_streamline(streamlines)

    renderer = window.Renderer()
    colors_all = []  # final shape will be [nr_streamlines, nr_points, 3]
    for jdx, sl in enumerate(streamlines):
        colors_sl = []
        for idx, p in enumerate(sl):
            # A streamline can have more points than segment indices (e.g.
            # the extra point added above); fall back to the previous index
            # in that case.
            if idx >= len(segment_idxs[jdx]):
                seg_idx = segment_idxs[jdx][idx - 1]
            else:
                seg_idx = segment_idxs[jdx][idx]

            m = metrics[int(seg_idx / ANTI_INTERPOL_MULT)]
            color = colors[int(m)]
            colors_sl.append(color)
        colors_all.append(
            colors_sl
        )  # this can not be converted to numpy array because last element has one more elem

    sl_actor = actor.streamtube(streamlines,
                                colors=colors_all,
                                linewidth=0.2,
                                opacity=1)
    renderer.add(sl_actor)

    # plot brain mask
    mask = nib.load(brain_mask_path).get_data()
    cont_actor = vtk_utils.contour_from_roi_smooth(mask,
                                                   affine=np.eye(4),
                                                   color=[.9, .9, .9],
                                                   opacity=.2,
                                                   smoothing=50)
    renderer.add(cont_actor)

    if show_color_bar:
        # The colour bar uses the metric range from before rescaling.
        lut_cmap = actor.colormap_lookup_table(scale_range=(metrics_min,
                                                            metrics_max),
                                               hue_range=(0.0, 0.0),
                                               saturation_range=(0.0, 1.0))
        renderer.add(actor.scalar_bar(lut_cmap))

    if orientation == "sagittal":
        renderer.set_camera(position=(-242.14, 81.28, 113.61),
                            focal_point=(109.90, 93.18, 50.86),
                            view_up=(0.18, 0.00, 0.98))
    elif orientation == "coronal":
        renderer.set_camera(position=(66.82, 352.47, 132.99),
                            focal_point=(72.17, 89.31, 60.83),
                            view_up=(0.00, -0.26, 0.96))
    elif orientation == "axial":
        # Keep the renderer's default camera.
        pass
    else:
        raise ValueError("Invalid orientation provided")

    # Use this to interactively get a new camera angle
    # window.show(renderer, size=img_size, reset_camera=False)
    # print(renderer.get_camera())

    window.record(renderer, out_path=output_path, size=img_size)
Пример #15
0
# ``dix`` is created outside this excerpt; only its 'affine' and 'cg.left'
# entries are read here.
affine = dix['affine']

"""
Store the cingulum bundle. A bundle is a list of streamlines.
"""

bundle = dix['cg.left']

"""
It happened that this bundle is in world coordinates and therefore we need to
transform it into native image coordinates so that it is in the same coordinate
space as the ``fa`` image.
"""

# Apply the inverse affine to move the bundle out of world coordinates.
bundle_native = transform_streamlines(bundle, np.linalg.inv(affine))

"""
Show every streamline with an orientation color
===============================================

This is the default option when you are using ``line`` or ``streamtube``.
"""

renderer = window.Renderer()

# No explicit colors are passed, so the actor's default coloring is used.
stream_actor = actor.line(bundle_native)

# Fixed camera placement for this view.
renderer.set_camera(position=(-176.42, 118.52, 128.20),
                    focal_point=(113.30, 128.31, 76.56),
                    view_up=(0.18, 0.00, 0.98))
Пример #16
0
def warp_streamlines(sft, deformation_data, source='ants'):
    """ Warp the streamlines of a tractogram using a deformation map.

    Supports ANTs and Dipy deformation maps. The points are processed in
    chunks and written back into the point buffer of the streamline
    sequence.

    Parameters
    ----------
    sft : tractogram object
        Object exposing ``to_rasmm()``, ``to_center()``, ``streamlines`` and
        ``affine``. It is switched to RASMM space with centre origin before
        warping; its ``affine`` is used to go from that space to voxel
        space (through its inverse) and back.
    deformation_data : numpy.ndarray
        4D numpy array containing a 3D displacement vector in each voxel.
    source : str
        Source of the deformation map [ants, dipy].

    Returns
    -------
    streamlines : ArraySequence
        The warped streamlines.

    Raises
    ------
    ValueError
        If ``source`` is neither ``'ants'`` nor ``'dipy'``.
    """
    # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs.
    flips = {'ants': [-1, -1, 1], 'dipy': [1, 1, 1]}
    if source not in flips:
        # Fail early, before the tractogram is modified, instead of hitting
        # a NameError on ``flip`` further down.
        raise ValueError("Unknown deformation source {!r}; expected 'ants' "
                         "or 'dipy'.".format(source))
    flip = flips[source]

    sft.to_rasmm()
    sft.to_center()
    streamlines = sft.streamlines
    transfo = sft.affine

    # Because of duplication, an iteration over chunks of points is necessary
    # for a big dataset (especially if not compressed)
    streamlines = ArraySequence(streamlines)
    nb_points = len(streamlines._data)
    chunk_size = 1000000
    inv_transfo = np.linalg.inv(transfo)

    for cur_position in range(0, nb_points, chunk_size):
        max_position = min(cur_position + chunk_size, nb_points)
        points = streamlines._data[cur_position:max_position]

        # To access the deformation information, we need to go in voxel space
        # No need for corner shift since we are doing interpolation
        cur_points_vox = np.array(transform_streamlines(points, inv_transfo)).T
        coords = cur_points_vox.tolist()

        x_def = map_coordinates(deformation_data[..., 0], coords, order=1)
        y_def = map_coordinates(deformation_data[..., 1], coords, order=1)
        z_def = map_coordinates(deformation_data[..., 2], coords, order=1)

        final_points = np.array(
            [flip[0] * x_def, flip[1] * y_def, flip[2] * z_def])

        # The Ants deformation is relative to world space
        if source == 'ants':
            final_points += np.array(points).T
        # Dipy transformation is relative to vox space
        else:
            final_points += cur_points_vox
            # NOTE(review): ``final_points`` is laid out as (3, n_points)
            # here while ``points`` above is passed as (n_points, 3) --
            # confirm that this call handles that layout as intended.
            transform_streamlines(final_points, transfo, in_place=True)
        streamlines._data[cur_position:max_position] = final_points.T

    # Return only after every chunk has been processed (the return used to
    # sit inside the loop, so only the first chunk was ever warped).
    return streamlines
Пример #17
0
    def setup(self, streamline, affine, evals=[0.001, 0, 0], sphere=None):
        """
        Set up the necessary components for the LiFE model: the matrix of
        fiber-contributions to the DWI signal, and the coordinates of voxels
        for which the equations will be solved

        Parameters
        ----------
        streamline : list
            Streamlines, each is an array of shape (n, 3)
        affine : 4 by 4 array
            Mapping from the streamline coordinates to the data. ``None``
            is treated as the identity.
        evals : list (3 items, optional)
            The eigenvalues of the canonical tensor used as a response
            function. Default:[0.001, 0, 0]. The list is only read here.
        sphere: `dipy.core.Sphere` instance.
            Whether to approximate (and cache) the signal on a discrete
            sphere. This may confer a significant speed-up in setting up the
            problem, but is not as accurate. If `False`, we use the exact
            gradients along the streamlines to calculate the matrix, instead of
            an approximation. Any other value (including the default
            ``None``) is forwarded to ``LifeSignalMaker``.

        Returns
        -------
        life_matrix : scipy.sparse.csr_matrix
            Sparse matrix with one row per (voxel, diffusion-weighted
            direction) pair and one column per streamline.
        vox_coords : ndarray
            Unique rounded voxel coordinates visited by the streamlines.
        """
        use_sphere = sphere is not False
        if use_sphere:
            SignalMaker = LifeSignalMaker(self.gtab,
                                          evals=evals,
                                          sphere=sphere)

        if affine is None:
            affine = np.eye(4)
        streamline = transform_streamlines(streamline, affine)
        # Assign some local variables, for shorthand:
        all_coords = np.concatenate(streamline)
        vox_coords = unique_rows(np.round(all_coords).astype(np.intp))
        del all_coords
        # We only consider the diffusion-weighted signals:
        n_bvecs = self.gtab.bvals[~self.gtab.b0s_mask].shape[0]
        v2f, v2fn = voxel2streamline(streamline, transformed=True,
                                     affine=affine, unique_idx=vox_coords)
        # How many fibers in each voxel (this will determine how many
        # components are in the matrix). ``values()`` is a view in
        # Python 3, which np.hstack does not accept, hence the list():
        n_unique_f = len(np.hstack(list(v2f.values())))
        # Preallocate these, which will be used to generate the sparse
        # matrix (``np.float`` was removed in NumPy 1.24; the builtin
        # ``float`` is the same dtype):
        f_matrix_sig = np.zeros(n_unique_f * n_bvecs, dtype=float)
        f_matrix_row = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)
        f_matrix_col = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)

        # Per-streamline signal, either from the cached sphere
        # approximation or from the exact gradients:
        if use_sphere:
            fiber_signal = [SignalMaker.streamline_signal(s)
                            for s in streamline]
        else:
            fiber_signal = [streamline_signal(s, self.gtab, evals)
                            for s in streamline]

        del streamline
        if use_sphere:
            del SignalMaker

        keep_ct = 0
        range_bvecs = np.arange(n_bvecs).astype(int)
        # In each voxel:
        for v_idx in range(vox_coords.shape[0]):
            mat_row_idx = (range_bvecs + v_idx * n_bvecs).astype(np.intp)
            # For each fiber in that voxel:
            for f_idx in v2f[v_idx]:
                # For each fiber-voxel combination, store the row/column
                # indices in the pre-allocated linear arrays
                f_matrix_row[keep_ct:keep_ct+n_bvecs] = mat_row_idx
                f_matrix_col[keep_ct:keep_ct+n_bvecs] = f_idx

                vox_fiber_sig = np.zeros(n_bvecs)
                for node_idx in v2fn[f_idx][v_idx]:
                    # Sum the signal from each node of the fiber in that voxel:
                    vox_fiber_sig += fiber_signal[f_idx][node_idx]
                # And add the summed thing into the corresponding rows:
                f_matrix_sig[keep_ct:keep_ct+n_bvecs] += vox_fiber_sig
                keep_ct = keep_ct + n_bvecs

        del v2f, v2fn
        # Allocate the sparse matrix, using the more memory-efficient 'csr'
        # format:
        life_matrix = sps.csr_matrix((f_matrix_sig,
                                     [f_matrix_row, f_matrix_col]))

        return life_matrix, vox_coords
Пример #18
0
def test_bundle_maps():
    """Exercise ``actor.line`` / ``actor.streamtube`` colour inputs.

    Covers a scalar volume with a lookup table, per-point scalar values,
    per-point RGB colours and a few further accepted colour arguments.
    """
    scene = window.Scene()
    bundle, _ = center_streamlines(fornix_streamlines())

    # Pure translation by 100 along every axis.
    shift_100 = np.eye(4)
    shift_100[:3, 3] = 100

    bundle = transform_streamlines(bundle, shift_100)

    # Constant volume with one plane of lower values.
    volume = 100 * np.ones((200, 200, 200))
    volume[100, :, :] = 100 * 0.5

    # create a nice orange-red colormap
    lut = actor.colormap_lookup_table(scale_range=(0., 100.),
                                      hue_range=(0., 0.1),
                                      saturation_range=(1, 1),
                                      value_range=(1., 1))

    # 1) colours looked up from a scalar volume
    volume_lines = actor.line(bundle, volume, linewidth=0.1,
                              lookup_colormap=lut)
    scene.add(volume_lines)
    scene.add(actor.scalar_bar(lut, ' '))

    report = window.analyze_scene(scene)

    npt.assert_almost_equal(report.actors, 1)

    scene.clear()

    # 2) one scalar value per point, rendered as tubes
    nb_points = np.sum([len(b) for b in bundle])
    point_values = 100 * np.random.rand(nb_points)

    tubes = actor.streamtube(bundle, point_values, linewidth=0.1,
                             lookup_colormap=lut)
    scene.add(tubes)

    report = window.analyze_scene(scene)
    npt.assert_equal(report.actors_classnames[0], 'vtkLODActor')

    scene.clear()

    # 3) one RGB colour per point
    point_colors = np.random.rand(nb_points, 3)

    rgb_lines = actor.line(bundle, point_colors, linewidth=2)
    scene.add(rgb_lines)

    report = window.analyze_scene(scene)
    npt.assert_equal(report.actors_classnames[0], 'vtkLODActor')

    snapshot = window.snapshot(scene)
    snapshot_report = window.analyze_snapshot(snapshot)
    npt.assert_equal(snapshot_report.objects, 1)

    # try other input options for colors
    scene.clear()
    actor.line(bundle, (1., 0.5, 0))
    actor.line(bundle, np.arange(len(bundle)))
    actor.line(bundle)
    per_streamline_colors = [np.random.rand(*b.shape) for b in bundle]
    actor.line(bundle, colors=per_streamline_colors)
Пример #19
0
With our current design it is easy to decide in which space you want the
streamlines and slices to appear. The default we have here is to appear in
world coordinates (RAS 1mm).
"""

# Toggle between world coordinates (True) and native image space (False).
world_coords = True

"""
If we want to see the objects in native space we need to make sure that all
objects which are currently in world coordinates are transformed back to
native space using the inverse of the affine.
"""

if not world_coords:
    # Imported here because it is only needed for the native-space case.
    from dipy.tracking.streamline import transform_streamlines
    # ``streamlines`` and ``affine`` are defined outside this excerpt.
    streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

"""
Now we create a ``Renderer`` object and add the streamlines using the ``line``
function and an image plane using the ``slice`` function.
"""

# Renderer that will hold the actors created below.
ren = window.Renderer()
stream_actor = actor.line(streamlines)

# On the native-space path the slicer is given an identity matrix instead of
# ``affine``; on the world-coordinates path ``affine`` is passed through.
if not world_coords:
    image_actor = actor.slicer(data, affine=np.eye(4))
else:
    image_actor = actor.slicer(data, affine)

"""
Пример #20
0
def direct_streamline_norm(streams,
                           fa_path,
                           ap_path,
                           dir_path,
                           track_type,
                           conn_model,
                           subnet,
                           node_radius,
                           dens_thresh,
                           ID,
                           roi,
                           min_span_tree,
                           disp_filt,
                           parc,
                           prune,
                           atlas,
                           labels_im_file,
                           parcellation,
                           labels,
                           coords,
                           norm,
                           binary,
                           atlas_t1w,
                           basedir_path,
                           curv_thr_list,
                           step_list,
                           traversal,
                           min_length,
                           t1w_brain,
                           run_dsn=False):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    When `run_dsn` is not True the function does no work: it prints a
    notice and passes its inputs through (see Returns).

    Parameters
    ----------
    streams : str
        File path to the input streamline array sequence in .trk format. It
        is loaded when `run_dsn` is True.
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_t1w : str
        File path to atlas parcellation Nifti1Image in T1w-conformed space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    traversal : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    t1w_brain : str
        File path to the T1w Nifti1Image.
    run_dsn : bool
        Only the literal value True enables the normalization; any other
        value skips it. Default is False.

    Returns
    -------
    streams_t1w : str
        File path to the normalized streamlines in .trk format when
        `run_dsn` is True; otherwise the `streams` input, unchanged.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    subnet : str
        Resting-state subnet based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_radius : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone subnet' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    parcellation : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_for_streams : str
        File path to atlas parcellation Nifti1Image in the same
        morphological space as the streamlines (`atlas_t1w` when `run_dsn`
        is True, `labels_im_file` otherwise).
    traversal : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to the warped FA Nifti1Image when `run_dsn` is True;
        otherwise the `fa_path` input, unchanged.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and structural
      brain networks. Network Neuroscience, 1-19.
    """
    if run_dsn is True:
        # Everything below is only needed on this branch, so the imports are
        # deferred here and the skip path stays free of heavy dependencies.
        import gc
        import os
        import os.path as op
        from dipy.io.stateful_tractogram import (Origin, Space,
                                                 StatefulTractogram)
        from dipy.io.streamline import load_tractogram, save_tractogram
        from dipy.tracking._utils import _mapping_to_voxel
        from dipy.tracking.streamline import transform_streamlines
        from dipy.tracking.utils import density_map
        from pynets.plotting.brain import show_template_bundles
        from pynets.registration import utils as regutils

        # makedirs(exist_ok=True) also creates missing parents and does not
        # race with another process creating the same directory.
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        os.makedirs(dsn_dir, exist_ok=True)

        namer_dir = f"{dir_path}/tractography"
        os.makedirs(namer_dir, exist_ok=True)

        # NOTE(review): `atlas_img` and `atlas_t1w_img` are not used below.
        # They are kept because loading them surfaces an unreadable path
        # before the registration runs -- confirm whether that is wanted.
        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)

        atlas_t1w_img = nib.load(atlas_t1w)
        t1w_brain_img = nib.load(t1w_brain)
        brain_mask = np.asarray(t1w_brain_img.dataobj).astype("bool")

        # Both output names share everything but their prefix and extension.
        subnet_tag = subnet + "_" if subnet is not None else ""
        roi_tag = (op.basename(roi).split(".")[0] + "_"
                   if roi is not None else "")
        if node_radius != "parc" and node_radius is not None:
            node_tag = f"_{node_radius}mm_"
        else:
            node_tag = "_"
        curv_tag = str(curv_thr_list).replace(", ", "_")
        step_tag = str(step_list).replace(", ", "_")
        name_suffix = (
            f"{subnet_tag}{roi_tag}{conn_model}{node_tag}"
            f"curv{curv_tag}step{step_tag}"
            f"tracktype-{track_type}_traversal-{traversal}"
            f"_minlength-{min_length}"
        )

        streams_t1w = f"{namer_dir}/streamlines_t1w_{name_suffix}.trk"
        density_t1w = f"{namer_dir}/density_map_t1w_{name_suffix}.nii.gz"

        streams_warp_png = '/tmp/dsn.png'

        # Registration step; its inputs are the T1w brain and the
        # anisotropic power image, with outputs written under `dsn_dir`.
        [mapping, affine_map,
         warped_fa] = regutils.wm_syn(t1w_brain, ap_path, dsn_dir)

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        streams_in_curr_grid = transform_streamlines(streamlines,
                                                     affine_map.affine_inv)

        streams_final_filt = regutils.warp_streamlines(t1w_brain_img.affine,
                                                       fa_img.affine, mapping,
                                                       warped_fa_img,
                                                       streams_in_curr_grid,
                                                       brain_mask)

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=t1w_brain_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_t1w, bbox_valid_check=True)
        warped_fa_img.uncache()

        # DSN QC plotting
        show_template_bundles(streams_final_filt_final, atlas_t1w,
                              streams_warp_png)

        nib.save(
            nib.Nifti1Image(
                density_map(streams_final_filt_final,
                            affine=np.eye(4),
                            vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_t1w,
        )

        # Drop the large intermediates before returning.
        del (
            tractogram,
            streamlines,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        assert len(coords) == len(labels)

        atlas_for_streams = atlas_t1w

    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_t1w = streams
        warped_fa = fa_path
        atlas_for_streams = labels_im_file

    return (streams_t1w, dir_path, track_type, conn_model, subnet, node_radius,
            dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune, atlas,
            parcellation, labels, coords, norm, binary, atlas_for_streams,
            traversal, warped_fa, min_length)
Пример #21
0
def track(peaks,
          seed_image,
          max_nr_fibers=2000,
          smooth=None,
          compress=0.1,
          bundle_mask=None,
          start_mask=None,
          end_mask=None,
          tracking_uncertainties=None,
          dilation=0,
          next_step_displacement_std=0.15,
          nr_cpus=-1,
          verbose=True):
    """
    Generate streamlines.

    Great speedup was achieved by:
    - only seeding in bundle_mask instead of entire image (seeding took very long)
    - calculating fiber length on the fly instead of using extra function which has to iterate over entire fiber a
    second time

    Parameters
    ----------
    peaks : array
        Indexed as a 4D array; component 0 of the last axis is negated on a
        private copy before tracking, so the caller's array is not modified.
    seed_image : image object
        Must provide ``header.get_zooms()`` (the first zoom is used as the
        spacing) and ``affine`` (applied to the final streamlines).
    max_nr_fibers : int
        Seeding stops once this many non-empty streamlines were collected;
        the result is truncated to this number.
    smooth : optional
        If truthy, passed as ``smoothing_factor`` to
        ``fiber_utils.smooth_streamlines``.
    compress : float, optional
        If truthy, the streamlines are compressed with this value as
        ``error_threshold``. ``True`` is accepted and means 0.1.
    bundle_mask : array
        Seeds are taken from the voxels equal to 1. Dilated by `dilation`
        iterations when `dilation` > 0.
    start_mask, end_mask : array
        Dilated by ``dilation + 1`` iterations. They are passed to
        ``binary_dilation`` unconditionally, so despite the ``None``
        defaults they presumably have to be provided.
    tracking_uncertainties : array, optional
        If given, rescaled to the range (0, 1).
    dilation : int
        Number of dilation iterations applied to the masks (see above).
    next_step_displacement_std : float
        Forwarded to ``process_seedpoint``.
    nr_cpus : int
        Number of worker processes; -1 uses ``psutil.cpu_count()``.
    verbose : bool
        Print progress information.

    Returns
    -------
    streamlines
        A list of streamlines after applying ``seed_image.affine``, or
        whatever the smoothing / compression helpers return when those steps
        are enabled.
    """

    # Flip along the x axis on a copy; flipping the caller's array in place
    # would silently change it (and flip it back on a second call).
    peaks = np.array(peaks)
    peaks[:, :, :, 0] *= -1
    # Add +1 dilation for start and end mask to be more robust
    start_mask = binary_dilation(start_mask,
                                 iterations=dilation + 1).astype(np.uint8)
    end_mask = binary_dilation(end_mask,
                               iterations=dilation + 1).astype(np.uint8)
    if dilation > 0:
        bundle_mask = binary_dilation(bundle_mask,
                                      iterations=dilation).astype(np.uint8)

    if tracking_uncertainties is not None:
        tracking_uncertainties = img_utils.scale_to_range(
            tracking_uncertainties, range=(0, 1))

    # Published as module globals before the pools are created; presumably
    # read by process_seedpoint in the worker processes.
    global _PEAKS
    _PEAKS = peaks
    global _BUNDLE_MASK
    _BUNDLE_MASK = bundle_mask
    global _START_MASK
    _START_MASK = start_mask
    global _END_MASK
    _END_MASK = end_mask
    global _TRACKING_UNCERTAINTIES
    _TRACKING_UNCERTAINTIES = tracking_uncertainties

    # Get list of coordinates of each voxel in mask to seed from those
    mask_coords = np.array(np.where(bundle_mask == 1)).transpose()
    spacing = seed_image.header.get_zooms()[0]

    max_nr_seeds = 100 * max_nr_fibers  # after how many seeds to abort (to avoid endless runtime)
    # How many seeds to process in each pool.map iteration
    seeds_per_batch = 5000

    if nr_cpus == -1:
        nr_processes = psutil.cpu_count()
    else:
        nr_processes = nr_cpus

    # The per-seed worker does not change between batches: build it once.
    seed_worker = partial(process_seedpoint,
                          next_step_displacement_std=next_step_displacement_std,
                          spacing=spacing)

    streamlines = []
    fiber_ctr = 0
    seed_ctr = 0
    # Processing seeds in batches so we can stop after we reached desired nr of streamlines. Not ideal. Could be
    #   optimised by more multiprocessing fanciness.
    while fiber_ctr < max_nr_fibers:
        pool = multiprocessing.Pool(processes=nr_processes)
        try:
            # For single-threaded debugging, replace with:
            #   [seed_worker(seed) for seed in seed_generator(mask_coords, seeds_per_batch)]
            streamlines_tmp = pool.map(
                seed_worker, seed_generator(mask_coords, seeds_per_batch))
        finally:
            # Release the worker processes even if a batch fails.
            pool.close()
            pool.join()

        streamlines += [sl for sl in streamlines_tmp
                        if len(sl) > 0]  # filter empty ones
        fiber_ctr = len(streamlines)
        if verbose:
            print("nr_fibs: {}".format(fiber_ctr))
        seed_ctr += seeds_per_batch
        if seed_ctr > max_nr_seeds:
            if verbose:
                print("Early stopping because max nr of seeds reached.")
            break

    if verbose:
        print("final nr streamlines: {}".format(len(streamlines)))

    # remove surplus of fibers (comes from multiprocessing)
    streamlines = streamlines[:max_nr_fibers]
    streamlines = Streamlines(streamlines)  # Generate streamlines object

    # Move from convention "0mm is in voxel corner" to convention "0mm is in voxel center". Most toolkits use the
    # convention "0mm is in voxel center".
    streamlines = fiber_utils.add_to_each_streamline(streamlines, -0.5)

    # move streamlines to coordinate space
    #  This is doing: streamlines(coordinate_space) = affine * streamlines(voxel_space)
    streamlines = list(transform_streamlines(streamlines, seed_image.affine))

    # Smoothing does not change overall results at all because is just little smoothing. Just removes small unevenness.
    if smooth:
        streamlines = fiber_utils.smooth_streamlines(streamlines,
                                                     smoothing_factor=smooth)

    if compress:
        # Honour the requested threshold; a bare True keeps the former
        # fixed value of 0.1.
        error_threshold = 0.1 if compress is True else compress
        streamlines = fiber_utils.compress_streamlines(
            streamlines,
            error_threshold=error_threshold,
            nr_cpus=nr_cpus)

    return streamlines
Пример #22
0
    def setup(self, streamline, affine, evals=[0.001, 0, 0], sphere=None):
        """
        Set up the necessary components for the LiFE model: the matrix of
        fiber-contributions to the DWI signal, and the coordinates of voxels
        for which the equations will be solved

        Parameters
        ----------
        streamline : list
            Streamlines, each is an array of shape (n, 3)
        affine : 4 by 4 array
            Mapping from the streamline coordinates to the data. ``None`` is
            replaced by the identity.
        evals : list (3 items, optional)
            The eigenvalues of the canonical tensor used as a response
            function. Default:[0.001, 0, 0]. The list is only read here,
            never modified.
        sphere: `dipy.core.Sphere` instance.
            Whether to approximate (and cache) the signal on a discrete
            sphere. This may confer a significant speed-up in setting up the
            problem, but is not as accurate. If `False`, we use the exact
            gradients along the streamlines to calculate the matrix, instead of
            an approximation. Defaults to use the 724-vertex symmetric sphere
            from :mod:`dipy.data`

        Returns
        -------
        life_matrix : scipy.sparse.csr_matrix
            One row per (voxel, diffusion-weighted direction) pair and one
            column per streamline.
        vox_coords : array
            The unique rounded voxel coordinates visited by the streamlines.
        """
        use_sphere = sphere is not False
        if use_sphere:
            SignalMaker = LifeSignalMaker(self.gtab,
                                          evals=evals,
                                          sphere=sphere)

        if affine is None:
            affine = np.eye(4)
        streamline = transform_streamlines(streamline, affine)
        # Assign some local variables, for shorthand:
        all_coords = np.concatenate(streamline)
        vox_coords = unique_rows(np.round(all_coords).astype(np.intp))
        del all_coords
        # We only consider the diffusion-weighted signals:
        n_bvecs = self.gtab.bvals[~self.gtab.b0s_mask].shape[0]
        v2f, v2fn = voxel2streamline(streamline,
                                     transformed=True,
                                     affine=affine,
                                     unique_idx=vox_coords)
        # How many fibers in each voxel (this will determine how many
        # components are in the matrix). np.hstack needs a real sequence,
        # not a dict view:
        n_unique_f = len(np.hstack(list(v2f.values())))
        # Preallocate these, which will be used to generate the sparse
        # matrix. The builtin ``float`` is what the removed ``np.float``
        # alias stood for:
        f_matrix_sig = np.zeros(n_unique_f * n_bvecs, dtype=float)
        f_matrix_row = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)
        f_matrix_col = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)

        # Predicted signal along every streamline, one entry per streamline:
        fiber_signal = []
        for s in streamline:
            if use_sphere:
                fiber_signal.append(SignalMaker.streamline_signal(s))
            else:
                fiber_signal.append(streamline_signal(s, self.gtab, evals))

        del streamline
        if use_sphere:
            del SignalMaker

        keep_ct = 0
        range_bvecs = np.arange(n_bvecs).astype(int)
        # In each voxel:
        for v_idx in range(vox_coords.shape[0]):
            mat_row_idx = (range_bvecs + v_idx * n_bvecs).astype(np.intp)
            # For each fiber in that voxel:
            for f_idx in v2f[v_idx]:
                # For each fiber-voxel combination, store the row/column
                # indices in the pre-allocated linear arrays
                f_matrix_row[keep_ct:keep_ct + n_bvecs] = mat_row_idx
                f_matrix_col[keep_ct:keep_ct + n_bvecs] = f_idx

                vox_fiber_sig = np.zeros(n_bvecs)
                for node_idx in v2fn[f_idx][v_idx]:
                    # Sum the signal from each node of the fiber in that voxel:
                    vox_fiber_sig += fiber_signal[f_idx][node_idx]
                # And add the summed thing into the corresponding rows:
                f_matrix_sig[keep_ct:keep_ct + n_bvecs] += vox_fiber_sig
                keep_ct = keep_ct + n_bvecs

        del v2f, v2fn
        # Allocate the sparse matrix, using the more memory-efficient 'csr'
        # format:
        life_matrix = sps.csr_matrix(
            (f_matrix_sig, [f_matrix_row, f_matrix_col]))

        return life_matrix, vox_coords
Пример #23
0
def transform_warp_sft(sft,
                       linear_transfo,
                       target,
                       inverse=False,
                       reverse_op=False,
                       deformation_data=None,
                       remove_invalid=True,
                       cut_invalid=False):
    """ Transform a tractogram using an affine, then optionally apply a warp
    from antsRegistration.
    Remove/Cut invalid streamlines to preserve sft validity.

    Note that ``sft.to_rasmm()`` and ``sft.to_center()`` are called on the
    input tractogram before anything else.

    Parameters
    ----------
    sft: StatefulTractogram
        Stateful tractogram object containing the streamlines to transform.
    linear_transfo: numpy.ndarray
        Linear transformation matrix to apply to the tractogram.
    target: Nifti filepath, image object, header
        Final reference for the tractogram after registration.
    inverse: boolean
        Apply the inverse linear transformation.
    reverse_op: boolean
        Apply both transformations in the reverse order (warp first, then
        the linear transformation).
    deformation_data: np.ndarray
        4D array containing a 3D displacement vector in each voxel.
    remove_invalid: boolean
        Remove the streamlines landing out of the bounding box. Ignored
        when `cut_invalid` is True.
    cut_invalid: boolean
        Cut invalid streamlines rather than removing them. Keep the longest
        segment only.

    Returns
    -------
    new_sft : StatefulTractogram
        New tractogram in RASMM space, referenced to `target`.
    """
    sft.to_rasmm()
    sft.to_center()

    if len(sft.streamlines) == 0:
        return StatefulTractogram(sft.streamlines, target, Space.RASMM)

    if inverse:
        linear_transfo = np.linalg.inv(linear_transfo)

    if not reverse_op:
        streamlines = transform_streamlines(sft.streamlines, linear_transfo)
    else:
        streamlines = sft.streamlines

    if deformation_data is not None:
        if not reverse_op:
            affine, _, _, _ = get_reference_info(target)
        else:
            affine = sft.affine

        streamlines = ArraySequence(streamlines)
        if reverse_op:
            # In this branch `streamlines` is still the input's own
            # sequence, and wrapping it may share its point buffer. The
            # warp below writes into that buffer, so work on a copy to
            # leave the caller's tractogram untouched.
            streamlines = streamlines.copy()

        inv_affine = np.linalg.inv(affine)
        nb_points = len(streamlines._data)
        chunk_size = 1000000

        # Because of duplication, an iteration over chunks of points is
        # necessary for a big dataset (especially if not compressed)
        for cur_position in range(0, nb_points, chunk_size):
            max_position = min(cur_position + chunk_size, nb_points)
            points = streamlines._data[cur_position:max_position]

            # To access the deformation information, we need to go in VOX space
            # No need for corner shift since we are doing interpolation
            cur_points_vox = np.array(transform_streamlines(
                points, inv_affine)).T

            # map_coordinates accepts the coordinate array as is; there is
            # no need to turn it into nested Python lists for each axis.
            x_def = map_coordinates(deformation_data[..., 0],
                                    cur_points_vox,
                                    order=1)
            y_def = map_coordinates(deformation_data[..., 1],
                                    cur_points_vox,
                                    order=1)
            z_def = map_coordinates(deformation_data[..., 2],
                                    cur_points_vox,
                                    order=1)

            # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
            final_points = np.array([-1 * x_def, -1 * y_def, z_def])
            final_points += np.array(points).T

            streamlines._data[cur_position:max_position] = final_points.T

    if reverse_op:
        streamlines = transform_streamlines(streamlines, linear_transfo)

    new_sft = StatefulTractogram(streamlines,
                                 target,
                                 Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)
    if cut_invalid:
        new_sft, _ = cut_invalid_streamlines(new_sft)
    elif remove_invalid:
        new_sft.remove_invalid_streamlines()

    return new_sft
Пример #24
0
# ``dix`` is defined earlier in the script; it is indexed by string key here.
affine = dix["affine"]

"""
Store the cingulum bundle. A bundle is a list of streamlines.
"""

bundle = dix["cg.left"]

"""
It happened that this bundle is in world coordinates and therefore we need to
transform it into native image coordinates so that it is in the same coordinate
space as the ``fa`` image.
"""

# Apply the inverse of ``affine`` to the streamlines of the bundle.
bundle_native = transform_streamlines(bundle, np.linalg.inv(affine))

"""
Show every streamline with an orientation color
===============================================

This is the default option when you are using ``line`` or ``streamtube``.
"""

renderer = window.Renderer()

stream_actor = actor.line(bundle_native)

# Fixed viewpoint for this scene, spelled out one setting per line.
camera_settings = dict(
    position=(-176.42, 118.52, 128.20),
    focal_point=(113.30, 128.31, 76.56),
    view_up=(0.18, 0.00, 0.98),
)
renderer.set_camera(**camera_settings)

renderer.add(stream_actor)
Пример #25
0
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_trk

# Direction getter built from ``odf`` used as a PMF, then local tracking.
prob_dg = ProbabilisticDirectionGetter.from_pmf(
    odf, max_angle=30., sphere=sphere)
streamline_generator = LocalTracking(
    prob_dg, stopping_criterion, seeds, affine, step_size=.5)
streamlines = Streamlines(streamline_generator)

color = colormap.line_colors(streamlines)
# Map the streamlines with the inverse of ``t1_aff`` before building tubes.
streamlines_t1 = list(transform_streamlines(streamlines, inv(t1_aff)))
streamlines_actor = actor.streamtube(streamlines_t1, color, linewidth=0.1)

# Two slicer actors on the same volume: one shown at x=40, a copy at z=35.
vol_actor = actor.slicer(t1_data)
vol_actor.display(x=40)
vol_actor2 = vol_actor.copy()
vol_actor2.display(z=35)

scene = window.Scene()
for scene_actor in (vol_actor, vol_actor2, streamlines_actor):
    scene.add(scene_actor)
if interactive:
    window.show(scene)
Пример #26
0
def bundle_analysis(model_bundle_folder,
                    bundle_folder,
                    orig_bundle_folder,
                    metric_folder,
                    group,
                    subject,
                    no_disks=100,
                    out_dir=''):
    """
    Applies statistical analysis on bundles and saves the results
    in a directory specified by ``out_dir``.

    The three bundle folders are listed and sorted by file name, and the
    i-th entries of the three listings are processed together. For each
    triplet the model bundle is resampled to ``no_disks`` points and
    clustered with QuickBundles; every point of the common-space bundle is
    then assigned to its nearest centroid point. That assignment is handed,
    together with each file found in ``metric_folder``, to ``dti_measures``
    (files named ``XX.nii.gz``) or ``peak_values`` (every other file).

    Parameters
    ----------
    model_bundle_folder : string
        Path to the folder holding the model bundle files.
    bundle_folder : string
        Path to the folder holding the bundle files in common space.
    orig_bundle_folder : string
        Path to the folder holding the bundle files in native space. The
        number of files in this folder determines how many bundles are
        processed.
    metric_folder : string
        Path to the folder holding the dti metric or/and peak files. It
        must contain ``fa.nii.gz``, whose affine is inverted and applied to
        the native-space bundles.
    group : string
        what group subject belongs to e.g. control or patient
    subject : string
        subject id e.g. 10001
    no_disks : integer, optional
        Number of disks used for dividing bundle into disks. (Default 100)
    out_dir : string, optional
        Output directory, passed through to ``dti_measures`` and
        ``peak_values``.

    Raises
    ------
    ValueError
        If ``model_bundle_folder`` or ``bundle_folder`` contains fewer
        files than ``orig_bundle_folder``.

    References
    ----------
    .. [Chandio19] Chandio, B.Q., S. Koudoro, D. Reagan, J. Harezlak,
    E. Garyfallidis, Bundle Analytics: a computational and statistical
    analyses framework for tractometric studies, Proceedings of:
    International Society of Magnetic Resonance in Medicine (ISMRM),
    Montreal, Canada, 2019.

    """

    mb = sorted(os.listdir(model_bundle_folder))
    bd = sorted(os.listdir(bundle_folder))
    org_bd = sorted(os.listdir(orig_bundle_folder))
    n = len(org_bd)

    # Bundles are matched purely by position in the sorted listings, so a
    # shorter listing would otherwise fail with an IndexError only after
    # some bundles have already been processed.
    if len(mb) < n or len(bd) < n:
        raise ValueError(
            'Expected at least {0} files in the model and common-space '
            'bundle folders to match the native-space bundle folder, '
            'found {1} and {2}'.format(n, len(mb), len(bd)))

    for io in range(n):
        mbundles = load_tractogram(os.path.join(model_bundle_folder, mb[io]),
                                   'same',
                                   bbox_valid_check=False).streamlines
        bundles = load_tractogram(os.path.join(bundle_folder, bd[io]),
                                  'same',
                                  bbox_valid_check=False).streamlines
        orig_bundles = load_tractogram(os.path.join(orig_bundle_folder,
                                                    org_bd[io]),
                                       'same',
                                       bbox_valid_check=False).streamlines

        mbundle_streamlines = set_number_of_points(mbundles,
                                                   nb_points=no_disks)

        qb = QuickBundles(threshold=25.,
                          metric=AveragePointwiseEuclideanMetric())
        clusters = qb.cluster(mbundle_streamlines)
        centroids = Streamlines(clusters.centroids)

        print('Number of centroids ', len(centroids._data))
        print('Model bundle ', mb[io])
        print('Number of streamlines in bundle in common space ', len(bundles))
        print('Number of streamlines in bundle in original space ',
              len(orig_bundles))

        # Nearest centroid point for every point of the common-space bundle.
        _, indx = cKDTree(centroids._data, 1,
                          copy_data=True).query(bundles._data, k=1)

        metric_files_names = os.listdir(metric_folder)
        _, affine = load_nifti(os.path.join(metric_folder, "fa.nii.gz"))

        affine_r = np.linalg.inv(affine)
        transformed_orig_bundles = transform_streamlines(
            orig_bundles, affine_r)

        # Model bundle file name without its last four characters.
        bm = mb[io][:-4]

        for metric_file in metric_files_names:

            # A fresh index array and a fresh dict are handed to the
            # measure functions for every metric file.
            ind = np.array(indx)
            dt = dict()
            metric_name = os.path.join(metric_folder, metric_file)

            if metric_file[2:] == '.nii.gz':
                # Two-character metric name followed by '.nii.gz'.
                fm = metric_file[:2]
                metric_data, _ = load_nifti(metric_name)

                dti_measures(transformed_orig_bundles, metric_data, dt, fm,
                             bm, subject, group, ind, out_dir)

            else:
                # Every other file is loaded as a peaks file and labelled
                # with its first three characters.
                fm = metric_file[:3]
                peaks = load_peaks(metric_name)
                peak_values(bundles, peaks, dt, fm, bm, subject, group, ind,
                            out_dir)
# Script entry point: cluster a tractogram with QuickBundles on an
# arc-length feature, drop the small clusters and render the rest.
if __name__ == '__main__':
    from pyfat.io.load import load_tck
    from dipy.io.pickles import save_pickle
    from dipy.segment.clustering import QuickBundles
    from dipy.segment.metric import SumPointwiseEuclideanMetric

    # load fiber data
    data_path = '/home/brain/workingdir/data/dwi/hcp/' \
                'preprocessed/response_dhollander/101006/result/CC_fib.tck'
    imgtck = load_tck(data_path)

    # Use the streamlines as loaded unless native coordinates are requested,
    # in which case the inverse of the tractogram affine is applied.
    streamlines = imgtck.streamlines
    world_coords = True
    if not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines,
                                            np.linalg.inv(imgtck.affine))

    metric = SumPointwiseEuclideanMetric(feature=ArcLengthFeature())
    qb = QuickBundles(threshold=2., metric=metric)
    clusters = qb.cluster(streamlines)

    # Keep only clusters holding at least 100 streamlines. Iterate over a
    # snapshot: removing from ``clusters`` while iterating it directly would
    # skip the element following each removal.
    for c in list(clusters):
        if len(c) < 100:
            clusters.remove_cluster(c)

    out_path = '/home/brain/workingdir/data/dwi/hcp/' \
               'preprocessed/response_dhollander/101006/result/CC_fib_length1_2.png'
    show(imgtck, clusters, out_path)
Пример #28
0
    def optimize(self, static, moving, mat=None):
        """ Find the minimum of the provided metric.

        Parameters
        ----------
        static : streamlines
            Reference or fixed set of streamlines.
        moving : streamlines
            Moving set of streamlines.
        mat : array
            Transformation (4, 4) matrix to start the registration. ``mat``
            is applied to moving. Default value None which means that initial
            transformation will be generated by shifting the centers of moving
            and static sets of streamlines to the origin.

        Returns
        -------
        map : StreamlineRegistrationMap

        Raises
        ------
        ValueError
            If the static streamlines, the moving streamlines, or both sets
            together do not all have the same number of points, or if
            ``self.method`` is neither ``'Powell'`` nor ``'L-BFGS-B'``.

        """

        msg = 'need to have the same number of points. Use '
        msg += 'set_number_of_points from dipy.tracking.streamline'

        static_lengths = np.array(list(map(len, static)))
        if not np.all(static_lengths == static[0].shape[0]):
            raise ValueError('Static streamlines ' + msg)

        moving_lengths = np.array(list(map(len, moving)))
        if not np.all(moving_lengths == moving[0].shape[0]):
            raise ValueError('Moving streamlines ' + msg)

        if not np.all(moving_lengths == static[0].shape[0]):
            raise ValueError('Static and moving streamlines ' + msg)

        # Reject unknown optimizers before doing any work; otherwise no
        # optimizer would be created and the failure would only surface
        # later as an unbound ``opt``.
        if self.method not in ('Powell', 'L-BFGS-B'):
            raise ValueError(
                "Unsupported optimization method {0!r}; expected 'Powell' "
                "or 'L-BFGS-B'".format(self.method))

        if mat is None:
            # No starting matrix: center both sets and remember the shifts
            # so they can be composed back into the final transformation.
            static_centered, static_shift = center_streamlines(static)
            moving_centered, moving_shift = center_streamlines(moving)
            static_mat = compose_matrix44(
                [static_shift[0], static_shift[1], static_shift[2], 0, 0, 0])

            moving_mat = compose_matrix44([
                -moving_shift[0], -moving_shift[1], -moving_shift[2], 0, 0, 0
            ])
        else:
            static_centered = static
            moving_centered = transform_streamlines(moving, mat)
            static_mat = np.eye(4)
            moving_mat = mat

        self.metric.setup(static_centered, moving_centered)

        distance = self.metric.distance

        if self.method == 'Powell':

            # Default options are stored on the instance the first time
            # they are needed.
            if self.options is None:
                self.options = {'xtol': 1e-6, 'ftol': 1e-6, 'maxiter': 1e6}

            opt = Optimizer(distance,
                            self.x0.tolist(),
                            method=self.method,
                            options=self.options,
                            evolution=self.evolution)

        else:  # 'L-BFGS-B', the only optimizer that receives ``bounds``

            if self.options is None:
                self.options = {
                    'maxcor': 10,
                    'ftol': 1e-7,
                    'gtol': 1e-5,
                    'eps': 1e-8,
                    'maxiter': 100
                }

            opt = Optimizer(distance,
                            self.x0.tolist(),
                            method=self.method,
                            bounds=self.bounds,
                            options=self.options,
                            evolution=self.evolution)

        if self.verbose:
            opt.print_summary()

        opt_mat = compose_matrix44(opt.xopt)

        # Full transformation: moving shift, optimized matrix, static shift.
        mat = compose_transformations(moving_mat, opt_mat, static_mat)

        mat_history = []

        if opt.evolution is not None:
            for vecs in opt.evolution:
                mat_history.append(
                    compose_transformations(moving_mat, compose_matrix44(vecs),
                                            static_mat))

        srm = StreamlineRegistrationMap(mat, opt.xopt, opt.fopt, mat_history,
                                        opt.nfev, opt.nit)
        del opt
        return srm
Пример #29
0
    def run(self, static_files, moving_files,
            x0='affine',
            rm_small_clusters=50,
            qbx_thr=[40, 30, 20, 15],
            num_threads=None,
            greater_than=50,
            less_than=250,
            nb_pts=20,
            progressive=True,
            out_dir='',
            out_moved='moved.trk',
            out_affine='affine.txt',
            out_stat_centroids='static_centroids.trk',
            out_moving_centroids='moving_centroids.trk',
            out_moved_centroids='moved_centroids.trk'):
        """ Streamline-based linear registration.

        For efficiency we apply the registration on cluster centroids and
        remove small clusters.

        Parameters
        ----------
        static_files : string
        moving_files : string
        x0 : string, optional
            rigid, similarity or affine transformation model (default affine)
        rm_small_clusters : int, optional
            Remove clusters that have less than `rm_small_clusters`
            (default 50)
        qbx_thr : variable int, optional
            Thresholds for QuickBundlesX (default [40, 30, 20, 15])
        num_threads : int, optional
            Number of threads. If None (default) then all available threads
            will be used. Only metrics using OpenMP will use this variable.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value (default 50)
        less_than : int, optional
            Keep streamlines have length less than this value (default 250)
        nb_pts : int, optional
            Number of points for discretizing each streamline (default 20)
        progressive : boolean, optional
            (default True)
        out_dir : string, optional
            Output directory (default input file directory)
        out_moved : string, optional
            Filename of moved tractogram (default 'moved.trk')
        out_affine : string, optional
            Filename of affine for SLR transformation (default 'affine.txt')
        out_stat_centroids : string, optional
            Filename of static centroids (default 'static_centroids.trk')
        out_moving_centroids : string, optional
            Filename of moving centroids (default 'moving_centroids.trk')
        out_moved_centroids : string, optional
            Filename of moved centroids (default 'moved_centroids.trk')

        Notes
        -----
        The order of operations is the following. First short or long
        streamlines are removed. Second the tractogram or a random selection
        of the tractogram is clustered with QuickBundlesX. Then SLR
        [Garyfallidis15]_ is applied.

        References
        ----------
        .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
        registration of white-matter fascicles in the space of
        streamlines", NeuroImage, 117, 124--140, 2015

        .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
        bundle alignment for group comparisons", ISMRM, 2014.

        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration
        and clustering, Neuroimage, 2017.
        """
        # NOTE(review): ``num_threads``, ``nb_pts`` and ``progressive`` are
        # accepted but never forwarded to ``slr_with_qbx`` in this body --
        # confirm whether they should be.
        # The input/output file names are not read directly here; the paths
        # used below come from the io iterator.
        io_it = self.get_io_iterator()

        # ``qbx_thr`` is only read, never mutated, so the list default is
        # not shared state.
        logging.info("QuickBundlesX clustering is in use")
        logging.info('QBX thresholds %s', qbx_thr)

        for static_file, moving_file, out_moved_file, out_affine_file, \
                static_centroids_file, moving_centroids_file, \
                moved_centroids_file in io_it:

            logging.info('Loading static file %s', static_file)
            logging.info('Loading moving file %s', moving_file)

            static, static_header = load_trk(static_file)
            # The moving header is not used: every output below is written
            # with the static header.
            moving, _ = load_trk(moving_file)

            moved, affine, centroids_static, centroids_moving = \
                slr_with_qbx(
                    static, moving, x0, rm_small_clusters=rm_small_clusters,
                    greater_than=greater_than, less_than=less_than,
                    qbx_thr=qbx_thr)

            logging.info('Saving output file %s', out_moved_file)
            save_trk(out_moved_file, moved, affine=np.eye(4),
                     header=static_header)

            logging.info('Saving output file %s', out_affine_file)
            np.savetxt(out_affine_file, affine)

            logging.info('Saving output file %s', static_centroids_file)
            save_trk(static_centroids_file, centroids_static, affine=np.eye(4),
                     header=static_header)

            logging.info('Saving output file %s', moving_centroids_file)
            save_trk(moving_centroids_file, centroids_moving,
                     affine=np.eye(4),
                     header=static_header)

            # Moved centroids are the moving centroids with the estimated
            # affine applied.
            centroids_moved = transform_streamlines(centroids_moving, affine)

            logging.info('Saving output file %s', moved_centroids_file)
            save_trk(moved_centroids_file, centroids_moved, affine=np.eye(4),
                     header=static_header)
Пример #30
0
def main():
    """Recognize a single bundle in a whole-brain tractogram.

    Command-line entry point. The model bundle is transformed with the
    matrix loaded from ``args.in_transfo`` (inverted when ``args.inverse``
    is set), RecoBundles is run on the whole-brain streamlines, and the
    recognized streamlines are saved, with their per-streamline and
    per-point data, to ``args.out_tractogram``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The model tractogram is loaded below, so check it with the other
    # required inputs.
    assert_inputs_exist(parser, [args.in_tractogram, args.in_transfo,
                                 args.in_model])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    wb_file = load_tractogram_with_reference(parser, args, args.in_tractogram)
    wb_streamlines = wb_file.streamlines
    model_file = load_tractogram_with_reference(parser, args, args.in_model)

    # Load the matrix once and invert the loaded copy when requested.
    transfo = load_matrix_in_any_format(args.in_transfo)
    if args.inverse:
        transfo = np.linalg.inv(transfo)

    before, after = compute_distance_barycenters(wb_file, model_file, transfo)
    if after > before:
        logging.warning('The distance between volumes barycenter should be '
                        'lower after registration. Maybe try using/removing '
                        '--inverse.')
        logging.info('Distance before: {}, Distance after: {}'.format(
            np.round(before, 3), np.round(after, 3)))
    model_streamlines = transform_streamlines(model_file.streamlines, transfo)

    rng = np.random.RandomState(args.seed)
    if args.in_pickle:
        # NOTE: unpickling executes code stored in the file; only pass
        # cluster maps from a trusted source.
        with open(args.in_pickle, 'rb') as infile:
            cluster_map = pickle.load(infile)
        # NOTE(review): this branch passes ``less_than=1`` while the branch
        # below passes ``greater_than=1`` -- confirm which length filter is
        # intended when a cluster map is supplied.
        reco_obj = RecoBundles(wb_streamlines,
                               cluster_map=cluster_map,
                               rng=rng,
                               less_than=1,
                               verbose=args.verbose)
    else:
        reco_obj = RecoBundles(wb_streamlines,
                               clust_thr=args.tractogram_clustering_thr,
                               rng=rng,
                               greater_than=1,
                               verbose=args.verbose)

    if args.out_pickle:
        with open(args.out_pickle, 'wb') as outfile:
            pickle.dump(reco_obj.cluster_map, outfile)
    _, indices = reco_obj.recognize(ArraySequence(model_streamlines),
                                    args.model_clustering_thr,
                                    pruning_thr=args.pruning_thr,
                                    slr_num_threads=args.slr_threads)
    new_streamlines = wb_streamlines[indices]
    new_data_per_streamlines = wb_file.data_per_streamline[indices]
    new_data_per_points = wb_file.data_per_point[indices]

    # An empty result is still written unless ``--no_empty`` was given.
    if not args.no_empty or new_streamlines:
        sft = StatefulTractogram(new_streamlines,
                                 wb_file.space_attributes,
                                 Space.RASMM,
                                 data_per_streamline=new_data_per_streamlines,
                                 data_per_point=new_data_per_points)
        save_tractogram(sft, args.out_tractogram)
Пример #31
0
def warp_tractogram(streamlines, transfo, deformation_data, source,
                    chunk_size=1000000):
    """
    Warp tractogram using a deformation map.
    Support Ants and Dipy deformation map.
    Apply warp in-place

    The points stored in ``streamlines._data`` are processed in chunks of
    ``chunk_size`` points. Each chunk is brought to voxel space with the
    inverse of ``transfo``, the displacement field is sampled there with
    linear interpolation, and the warped points are written back into
    ``streamlines._data``.

    Parameters
    ----------
    streamlines: ArraySequence
        Streamlines as loaded by the nibabel API (RASMM). Must expose its
        points as a ``_data`` array with one point per row.
    transfo: numpy.ndarray
        (4, 4) affine whose inverse brings the points from RASMM to voxel
        space; the matrix itself brings voxel coordinates back to RASMM.
    deformation_data: numpy.ndarray
        4D numpy array containing a 3D displacement vector in each voxel
    source: str
        Source of the deformation map [ants, dipy]
    chunk_size: int, optional
        Number of points warped per iteration (default 1000000).

    Raises
    ------
    ValueError
        If ``source`` is neither 'ants' nor 'dipy'.
    """

    # ITK is in LPS and nibabel is in RAS, a flip is necessary for ANTs
    if source == 'ants':
        flip = (-1, -1, 1)
    elif source == 'dipy':
        flip = (1, 1, 1)
    else:
        raise ValueError(
            "Unsupported deformation source {0!r}; expected 'ants' or "
            "'dipy'".format(source))

    transfo = np.asarray(transfo, dtype=float)
    inv_transfo = np.linalg.inv(transfo)
    nb_points = len(streamlines._data)

    # Iterating over chunks of points bounds the size of the temporary
    # arrays for a big dataset (especially if not compressed).
    for start in range(0, nb_points, chunk_size):
        stop = min(start + chunk_size, nb_points)
        points = streamlines._data[start:stop]

        # To access the deformation information, we need to go in voxel space
        points_vox = np.dot(points, inv_transfo[:3, :3].T) + inv_transfo[:3, 3]

        # map_coordinates expects one row of coordinates per array axis.
        coords = points_vox.T
        displacement = np.stack(
            [flip[axis] * ndimage.map_coordinates(deformation_data[..., axis],
                                                  coords, order=1)
             for axis in range(3)],
            axis=1)

        if source == 'ants':
            # The displacement is added to the original (RASMM) points.
            warped = points + displacement
        else:
            # The displacement is added in voxel space, then the points are
            # brought back in world space to be saved.
            warped = (np.dot(points_vox + displacement, transfo[:3, :3].T)
                      + transfo[:3, 3])

        streamlines._data[start:stop] = warped
Пример #32
0
		trk_preprocess_postrigid_affine = os.path.join(path_trk_tempdir, f'{subj}_preprocess_postrigid_affine.trk')

		#trans = os.path.join(work_dir, "preprocess", "base_images", "translation_xforms", f"{subj}_0DerivedInitialMovingTranslation.mat")
		#rigid = os.path.join(work_dir, "dwi", f"{subj}_rigid.mat")
		#affine = os.path.join(work_dir, "dwi", f"{subj}_affine.mat")
		#runno_to_MDT = os.path.join(work_dir, f'dwi/SyN_0p5_3_0p5_fa/faMDT_NoNameYet_n37_i6/reg_diffeo/{subj}_to_MDT_warp.nii.gz')


		_, exists = check_files([trans, rigid, affine, runno_to_MDT])
		if np.any(exists==0):
		    raise Exception('missing transform file')

		affine_map_test = get_affine_transform(reference, subj_dwi)
		streamlines, header = unload_trk(subj_trk)

		tmp2_streamlines = transform_streamlines(streamlines, np.linalg.inv(affine_map_test), in_place=False)

		if (not os.path.exists(trk_filepath_tmp2) or overwrite) and save_temp_files:
		    save_trk_header(filepath= trk_filepath_tmp2, streamlines = tmp2_streamlines, header = header, affine=np.eye(4), verbose=verbose)

		#tmp2_streamlines, header_new = unload_trk(trk_filepath_tmp2)
		affine_transform, newaffine = get_flip_affine(orientation_in, orientation_out)
		affine_transform_new = affine_transform

		#affine_transform[:3,3] = affine_transform.diagonal()[:3]* header[0][:3,3] - header[0][:3,3]

		center1 = header[0][:3, 3]
		center2 = affine_transform.diagonal()[:3] * center1[:3]
		affine_transform_new[:3, 3] = center1 - center2
		affine_transform_new[1, 3] = affine_transform[0, 3]
Пример #33
0
###############################################################################
# With our current design it is easy to decide in which space you want the
# streamlines and slices to appear. The default we have here is to appear in
# world coordinates (RAS 1mm).

world_coords = True

###############################################################################
# If we want to see the objects in native space we need to make sure that all
# objects which are currently in world coordinates are transformed back to
# native space using the inverse of the affine.


if not world_coords:
    from dipy.tracking.streamline import transform_streamlines
    streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

###############################################################################
# Now we create a ``Scene`` object, a streamline actor using the ``line``
# function and an image plane actor using the ``slicer`` function.

scene = window.Scene()
stream_actor = actor.line(streamlines)

# In native space the slicer is given an identity affine; in world
# coordinates it is given the image affine.
if not world_coords:
    image_actor_z = actor.slicer(data, affine=np.eye(4))
else:
    image_actor_z = actor.slicer(data, affine)

###############################################################################
# We can also change the opacity of the slicer.
Пример #34
0
    def run(self,
            static_files,
            moving_files,
            x0='affine',
            rm_small_clusters=50,
            qbx_thr=[40, 30, 20, 15],
            num_threads=None,
            greater_than=50,
            less_than=250,
            nb_pts=20,
            progressive=True,
            out_dir='',
            out_moved='moved.trk',
            out_affine='affine.txt',
            out_stat_centroids='static_centroids.trk',
            out_moving_centroids='moving_centroids.trk',
            out_moved_centroids='moved_centroids.trk'):
        """ Streamline-based linear registration.

        For efficiency we apply the registration on cluster centroids and
        remove small clusters.

        Parameters
        ----------
        static_files : string
        moving_files : string
        x0 : string, optional
            rigid, similarity or affine transformation model (default affine)
        rm_small_clusters : int, optional
            Remove clusters that have less than `rm_small_clusters`
            (default 50)
        qbx_thr : variable int, optional
            Thresholds for QuickBundlesX (default [40, 30, 20, 15])
        num_threads : int, optional
            Number of threads. If None (default) then all available threads
            will be used. Only metrics using OpenMP will use this variable.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value (default 50)
        less_than : int, optional
            Keep streamlines have length less than this value (default 250)
        nb_pts : int, optional
            Number of points for discretizing each streamline (default 20)
        progressive : boolean, optional
            (default True)
        out_dir : string, optional
            Output directory (default input file directory)
        out_moved : string, optional
            Filename of moved tractogram (default 'moved.trk')
        out_affine : string, optional
            Filename of affine for SLR transformation (default 'affine.txt')
        out_stat_centroids : string, optional
            Filename of static centroids (default 'static_centroids.trk')
        out_moving_centroids : string, optional
            Filename of moving centroids (default 'moving_centroids.trk')
        out_moved_centroids : string, optional
            Filename of moved centroids (default 'moved_centroids.trk')

        Notes
        -----
        The order of operations is the following. First short or long
        streamlines are removed. Second the tractogram or a random selection
        of the tractogram is clustered with QuickBundlesX. Then SLR
        [Garyfallidis15]_ is applied.

        References
        ----------
        .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
                registration of white-matter fascicles in the space of
                streamlines", NeuroImage, 117, 124--140, 2015
        .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
                bundle alignment for group comparisons", ISMRM, 2014.
        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
                bundles using local and global streamline-based registration
                and clustering, Neuroimage, 2017.
        """
        # NOTE(review): ``num_threads``, ``nb_pts`` and ``progressive`` are
        # part of the signature but are not passed on to ``slr_with_qbx``
        # here -- confirm whether that is intended.
        io_it = self.get_io_iterator()

        # The list default of ``qbx_thr`` is only read in this method.
        logging.info("QuickBundlesX clustering is in use")
        logging.info('QBX thresholds %s', qbx_thr)

        for static_file, moving_file, out_moved_file, out_affine_file, \
                static_centroids_file, moving_centroids_file, \
                moved_centroids_file in io_it:

            logging.info('Loading static file %s', static_file)
            logging.info('Loading moving file %s', moving_file)

            static, static_header = load_trk(static_file)
            # Only the static header is reused when saving; the moving
            # header is discarded.
            moving, _ = load_trk(moving_file)

            moved, affine, centroids_static, centroids_moving = \
                slr_with_qbx(
                    static, moving, x0, rm_small_clusters=rm_small_clusters,
                    greater_than=greater_than, less_than=less_than,
                    qbx_thr=qbx_thr)

            logging.info('Saving output file %s', out_moved_file)
            save_trk(out_moved_file,
                     moved,
                     affine=np.eye(4),
                     header=static_header)

            logging.info('Saving output file %s', out_affine_file)
            np.savetxt(out_affine_file, affine)

            logging.info('Saving output file %s', static_centroids_file)
            save_trk(static_centroids_file,
                     centroids_static,
                     affine=np.eye(4),
                     header=static_header)

            logging.info('Saving output file %s', moving_centroids_file)
            save_trk(moving_centroids_file,
                     centroids_moving,
                     affine=np.eye(4),
                     header=static_header)

            # Apply the estimated affine to the moving centroids before
            # saving them as the moved centroids.
            centroids_moved = transform_streamlines(centroids_moving, affine)

            logging.info('Saving output file %s', moved_centroids_file)
            save_trk(moved_centroids_file,
                     centroids_moved,
                     affine=np.eye(4),
                     header=static_header)
Пример #35
0
def test_bundle_maps():
    """Smoke-test line/streamtube actors with volume, per-point and RGB colors.

    The fornix bundle is centered, shifted by +100 along each axis and then
    rendered several times with different color inputs; after each render the
    renderer (or a snapshot of it) is analyzed and checked.
    """
    scene = window.renderer()
    streamlines, _ = center_streamlines(fornix_streamlines())

    # Homogeneous translation of +100 on x, y and z.
    translation = np.eye(4)
    translation[:3, 3] = 100
    streamlines = transform_streamlines(streamlines, translation)

    # Constant scalar volume with one plane of lower values.
    volume = np.full((200, 200, 200), 100.)
    volume[100] = 50.

    # Orange-red lookup table covering the volume's value range.
    colormap = actor.colormap_lookup_table(scale_range=(0., 100.),
                                           hue_range=(0., 0.1),
                                           saturation_range=(1, 1),
                                           value_range=(1., 1))

    # 1) Lines colored by sampling the scalar volume, plus a scalar bar.
    lines = actor.line(streamlines, volume, linewidth=0.1,
                       lookup_colormap=colormap)
    window.add(scene, lines)
    window.add(scene, actor.scalar_bar(colormap, ' '))
    npt.assert_almost_equal(window.analyze_renderer(scene).actors, 1)

    scene.clear()

    # 2) Streamtubes colored by one random scalar per point.
    total_points = np.sum([len(sl) for sl in streamlines])
    point_scalars = 100 * np.random.rand(total_points)
    tubes = actor.streamtube(streamlines, point_scalars, linewidth=0.1,
                             lookup_colormap=colormap)
    scene.add(tubes)
    tube_report = window.analyze_renderer(scene)
    npt.assert_equal(tube_report.actors_classnames[0], 'vtkLODActor')

    scene.clear()

    # 3) Lines colored by one random RGB triplet per point.
    point_rgb = np.random.rand(total_points, 3)
    rgb_lines = actor.line(streamlines, point_rgb, linewidth=2)
    scene.add(rgb_lines)
    rgb_report = window.analyze_renderer(scene)
    npt.assert_equal(rgb_report.actors_classnames[0], 'vtkLODActor')

    snapshot = window.snapshot(scene)
    npt.assert_equal(window.analyze_snapshot(snapshot).objects, 1)

    # 4) Remaining color inputs only need to construct without raising:
    #    a single RGB tuple, one scalar per streamline, the default, and a
    #    list of per-streamline arrays shaped like the streamlines.
    scene.clear()
    actor.line(streamlines, (1., 0.5, 0))
    actor.line(streamlines, np.arange(len(streamlines)))
    actor.line(streamlines)
    per_streamline_colors = [np.random.rand(*sl.shape) for sl in streamlines]
    actor.line(streamlines, colors=per_streamline_colors)
Пример #36
0
    def run(self, static_files, moving_files,
            x0='affine',
            rm_small_clusters=50,
            qbx_thr=[40, 30, 20, 15],
            num_threads=None,
            greater_than=50,
            less_than=250,
            nb_pts=20,
            progressive=True,
            out_dir='',
            out_moved='moved.trk',
            out_affine='affine.txt',
            out_stat_centroids='static_centroids.trk',
            out_moving_centroids='moving_centroids.trk',
            out_moved_centroids='moved_centroids.trk'):
        """ Streamline-based linear registration.

        For efficiency we apply the registration on cluster centroids and
        remove small clusters.

        Parameters
        ----------
        static_files : string
        moving_files : string
        x0 : string, optional
            rigid, similarity or affine transformation model.
        rm_small_clusters : int, optional
            Remove clusters that have less than `rm_small_clusters`.
        qbx_thr : variable int, optional
            Thresholds for QuickBundlesX.
        num_threads : int, optional
            Number of threads to be used for OpenMP parallelization. If None
            (default) the value of OMP_NUM_THREADS environment variable is
            used if it is set, otherwise all available threads are used. If
            < 0 the maximal number of threads minus |num_threads + 1| is used
            (enter -1 to use as many threads as possible). 0 raises an error.
            Only metrics using OpenMP will use this variable.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value.
        less_than : int, optional
            Keep streamlines that have length less than this value.
        nb_pts : int, optional
            Number of points for discretizing each streamline.
        progressive : boolean, optional
        out_dir : string, optional
            Output directory. (default current directory)
        out_moved : string, optional
            Filename of moved tractogram.
        out_affine : string, optional
            Filename of affine for SLR transformation.
        out_stat_centroids : string, optional
            Filename of static centroids.
        out_moving_centroids : string, optional
            Filename of moving centroids.
        out_moved_centroids : string, optional
            Filename of moved centroids.

        Notes
        -----
        The order of operations is the following. First short or long
        streamlines are removed. Second the tractogram or a random selection
        of the tractogram is clustered with QuickBundlesX. Then SLR
        [Garyfallidis15]_ is applied.

        References
        ----------
        .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
        registration of white-matter fascicles in the space of
        streamlines", NeuroImage, 117, 124--140, 2015

        .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
        bundle alignment for group comparisons", ISMRM, 2014.

        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration
        and clustering, NeuroImage, 2017.
        """
        # NOTE(review): `num_threads`, `nb_pts` and `progressive` are accepted
        # but are not forwarded to `slr_with_qbx` anywhere in this method.

        def _log_and_save(streamlines, out_path, header):
            # Log the destination, wrap the streamlines in a Tractogram with
            # an identity affine_to_rasmm and write it with the given header.
            logging.info('Saving output file {0}'.format(out_path))
            tractogram = nib.streamlines.Tractogram(
                streamlines, affine_to_rasmm=np.eye(4))
            nib.streamlines.save(tractogram, out_path, header=header)

        io_it = self.get_io_iterator()

        logging.info("QuickBundlesX clustering is in use")
        logging.info('QBX thresholds {0}'.format(qbx_thr))

        for (static_file, moving_file, out_moved_file, out_affine_file,
             static_centroids_file, moving_centroids_file,
             moved_centroids_file) in io_it:

            logging.info('Loading static file {0}'.format(static_file))
            logging.info('Loading moving file {0}'.format(moving_file))

            static_obj = nib.streamlines.load(static_file)
            moving_obj = nib.streamlines.load(moving_file)

            static = static_obj.streamlines
            static_header = static_obj.header
            moving = moving_obj.streamlines
            moving_header = moving_obj.header

            registration = slr_with_qbx(
                static, moving, x0,
                rm_small_clusters=rm_small_clusters,
                greater_than=greater_than,
                less_than=less_than,
                qbx_thr=qbx_thr)
            moved, affine, centroids_static, centroids_moving = registration

            # Registered tractogram is written with the moving file's header.
            _log_and_save(moved, out_moved_file, moving_header)

            logging.info('Saving output file {0}'.format(out_affine_file))
            np.savetxt(out_affine_file, affine)

            _log_and_save(centroids_static, static_centroids_file,
                          static_header)
            _log_and_save(centroids_moving, moving_centroids_file,
                          moving_header)

            # Apply the estimated affine to the moving centroids and save
            # the result alongside the other outputs.
            centroids_moved = transform_streamlines(centroids_moving, affine)
            _log_and_save(centroids_moved, moved_centroids_file,
                          moving_header)
Пример #37
0
    def setup(self, streamline, affine, evals=[0.001, 0, 0], sphere=None):
        """
        Build the sparse LiFE matrix of per-fiber signal contributions and
        the list of voxels that the streamlines pass through.

        Parameters
        ----------
        streamline : list
            Streamlines, each is an array of shape (n, 3)
        affine : 4 by 4 array
            Mapping from the streamline coordinates to the data. When None,
            the identity is used.
        evals : list (3 items, optional)
            The eigenvalues of the canonical tensor used as a response function

        sphere: `dipy.core.Sphere` instance.
            Anything other than `False` makes the per-node signals come from
            a `LifeSignalMaker` built with this sphere (an approximation on a
            discrete sphere). With `False`, `streamline_signal` is called
            with the gradient table instead. Per the original documentation,
            None selects the 724-vertex symmetric sphere from
            :mod:`dipy.data`; that default is resolved outside this method.

        Returns
        -------
        life_matrix : sparse CSR matrix
            One row per (voxel, diffusion-weighted direction) pair and one
            column per streamline.
        vox_coords : ndarray
            Unique integer voxel coordinates visited by the streamlines.
        """
        approximate = sphere is not False
        if approximate:
            signal_maker = LifeSignalMaker(self.gtab,
                                           evals=evals,
                                           sphere=sphere)

        if affine is None:
            affine = np.eye(4)
        streamline = transform_streamlines(streamline, affine)

        # Voxels touched by any streamline node (coordinates cast to int).
        vox_coords = unique_rows(np.concatenate(streamline).astype(int))

        # Only the diffusion-weighted (non-b0) measurements are modelled.
        n_bvecs = self.gtab.bvals[~self.gtab.b0s_mask].shape[0]

        v2f, v2fn = voxel2streamline(streamline, transformed=True,
                                     affine=affine, unique_idx=vox_coords)

        # One block of n_bvecs entries per (voxel, fiber) pair flagged in v2f.
        n_entries = np.sum(v2f) * n_bvecs
        sig_entries = np.zeros(n_entries)
        row_entries = np.zeros(n_entries)
        col_entries = np.zeros(n_entries)

        if approximate:
            fiber_signal = [signal_maker.streamline_signal(s)
                            for s in streamline]
        else:
            fiber_signal = [streamline_signal(s, self.gtab, evals)
                            for s in streamline]

        bvec_offsets = np.arange(n_bvecs)
        cursor = 0
        for v_idx in range(vox_coords.shape[0]):
            # Matrix rows that belong to this voxel.
            voxel_rows = bvec_offsets + v_idx * n_bvecs
            for f_idx in np.where(v2f[v_idx])[0]:
                # Accumulate the demeaned signal of every node of this fiber
                # that falls inside the voxel.
                summed = np.zeros(n_bvecs)
                for node_idx in np.where(v2fn[f_idx] == v_idx)[0]:
                    node_signal = fiber_signal[f_idx][node_idx]
                    summed += (node_signal - np.mean(node_signal))

                block = slice(cursor, cursor + n_bvecs)
                row_entries[block] = voxel_rows
                col_entries[block] = f_idx
                sig_entries[block] = summed
                cursor += n_bvecs

        # Assemble in COO form, then convert to the more compact CSR format.
        coo = sps.coo_matrix((sig_entries, [row_entries, col_entries]))
        life_matrix = coo.tocsr()

        return life_matrix, vox_coords
Пример #38
0
    def setup(self,
              streamline,
              affine,
              evals=[0.001, 0, 0],
              sphere=None,
              processes=1,
              verbose=False):
        """
        Set up the necessary components for the LiFE model: the matrix of
        fiber-contributions to the DWI signal, and the coordinates of voxels
        for which the equations will be solved

        Parameters
        ----------
        streamline : list
            Streamlines, each is an array of shape (n, 3)
        affine : array_like (4, 4)
            Transformation applied to the streamlines (through
            `transform_streamlines`) before their coordinates are rounded to
            voxel indices.
            NOTE(review): this was previously documented as the
            voxel_to_rasmm matrix, but the code applies it directly without
            inverting it -- confirm the expected direction with callers.
        evals : list (3 items, optional)
            The eigenvalues of the canonical tensor used as a response
            function. Default:[0.001, 0, 0].
        sphere: `dipy.core.Sphere` instance.
            Whether to approximate (and cache) the signal on a discrete
            sphere. Anything other than `False` builds a `LifeSignalMaker`
            with this sphere; with `False`, `streamline_signal` is called
            with the gradient table instead.
        processes : int, optional
            When greater than 1, the per-streamline signals are computed by
            `fiber_treatment` in a multiprocessing pool of this size;
            otherwise they are computed serially. Default: 1.
        verbose : bool, optional
            Print the time spent computing the fiber signals. Default: False.

        Returns
        -------
        life_matrix : sparse CSR matrix
            One row per (voxel, diffusion-weighted direction) pair and one
            column per streamline.
        vox_coords : ndarray
            Unique voxel coordinates visited by the streamlines.
        """
        if sphere is not False:
            SignalMaker = LifeSignalMaker(self.gtab,
                                          evals=evals,
                                          sphere=sphere)
        else:
            # Keep the name bound: the multiprocessing branch always passes
            # it to `fiber_treatment` (which presumably ignores it when
            # `sphere` is False -- TODO confirm).
            SignalMaker = None

        streamline = transform_streamlines(streamline, affine)

        all_coords = np.concatenate(streamline)
        vox_coords = unique_rows(np.round(all_coords).astype(np.intp))
        del all_coords
        # We only consider the diffusion-weighted signals:
        n_bvecs = self.gtab.bvals[~self.gtab.b0s_mask].shape[0]
        v2f, v2fn = voxel2streamline(streamline,
                                     np.eye(4),
                                     unique_idx=vox_coords)
        # How many fibers in each voxel (this will determine how many
        # components are in the matrix). `np.hstack` needs a real sequence,
        # not a dict view:
        n_unique_f = len(np.hstack(list(v2f.values())))

        print("computing the fiber signal values")
        duration1 = time()
        if processes > 1:
            # The context manager tears the pool down even if a worker
            # raises; results are fully collected before it exits.
            with mp.Pool(processes) as pool:
                fiber_signal = pool.starmap(
                    fiber_treatment,
                    [(s, idx, self.gtab, evals, SignalMaker, sphere)
                     for idx, s in enumerate(streamline)])
        elif sphere is not False:
            fiber_signal = [SignalMaker.streamline_signal(s)
                            for s in streamline]
        else:
            fiber_signal = [streamline_signal(s, self.gtab, evals)
                            for s in streamline]

        if verbose:
            print("Obtaining fiber signal process done in " +
                  str(time() - duration1) + "s")

        # Release the (potentially large) signal cache as early as possible.
        del SignalMaker

        # Preallocate these, which will be used to generate the sparse
        # matrix. The builtin `float` replaces the removed `np.float` alias:
        f_matrix_sig = np.zeros(n_unique_f * n_bvecs, dtype=float)
        f_matrix_row = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)
        f_matrix_col = np.zeros(n_unique_f * n_bvecs, dtype=np.intp)
        del streamline

        keep_ct = 0
        range_bvecs = np.arange(n_bvecs).astype(int)
        # In each voxel:
        for v_idx in range(vox_coords.shape[0]):
            mat_row_idx = (range_bvecs + v_idx * n_bvecs).astype(np.intp)
            # For each fiber in that voxel:
            for f_idx in v2f[v_idx]:
                # For each fiber-voxel combination, store the row/column
                # indices in the pre-allocated linear arrays
                f_matrix_row[keep_ct:keep_ct + n_bvecs] = mat_row_idx
                f_matrix_col[keep_ct:keep_ct + n_bvecs] = f_idx

                vox_fiber_sig = np.zeros(n_bvecs)
                for node_idx in v2fn[f_idx][v_idx]:
                    # Sum the signal from each node of the fiber in that
                    # voxel. An out-of-range index propagates as the original
                    # IndexError, with its message and traceback intact.
                    vox_fiber_sig += fiber_signal[f_idx][node_idx]
                # And add the summed thing into the corresponding rows:
                f_matrix_sig[keep_ct:keep_ct + n_bvecs] += vox_fiber_sig
                keep_ct = keep_ct + n_bvecs

        del v2f, v2fn
        # Allocate the sparse matrix, using the more memory-efficient 'csr'
        # format:
        life_matrix = sps.csr_matrix(
            (f_matrix_sig, [f_matrix_row, f_matrix_col]))

        return life_matrix, vox_coords
Пример #39
0
    def optimize(self, static, moving, mat=None):
        """ Register ``moving`` to ``static`` by minimizing the metric.

        Parameters
        ----------
        static : streamlines
            Reference or fixed set of streamlines.
        moving : streamlines
            Moving set of streamlines.
        mat : array
            Transformation (4, 4) matrix to start the registration. ``mat``
            is applied to moving. With the default None, both sets are
            centered with ``center_streamlines`` and the resulting shifts are
            folded into the returned transformation.

        Returns
        -------
        map : StreamlineRegistrationMap

        Raises
        ------
        ValueError
            If the streamlines of ``static`` or ``moving`` do not all share
            the number of points of ``static[0]``.

        """
        hint = ('need to have the same number of points. Use '
                'set_number_of_points from dipy.tracking.streamline')

        static_lengths = np.array([len(s) for s in static])
        if np.any(static_lengths != static[0].shape[0]):
            raise ValueError('Static streamlines ' + hint)

        moving_lengths = np.array([len(s) for s in moving])
        if np.any(moving_lengths != moving[0].shape[0]):
            raise ValueError('Moving streamlines ' + hint)

        if np.any(moving_lengths != static[0].shape[0]):
            raise ValueError('Static and moving streamlines ' + hint)

        if mat is not None:
            # Caller supplied the starting transform: only moving is changed.
            static_centered = static
            moving_centered = transform_streamlines(moving, mat)
            static_mat = np.eye(4)
            moving_mat = mat
        else:
            static_centered, static_shift = center_streamlines(static)
            moving_centered, moving_shift = center_streamlines(moving)
            # Translation-only parameter vectors: the static shift as is,
            # the moving shift negated.
            static_mat = compose_matrix44(
                [static_shift[i] for i in range(3)] + [0, 0, 0])
            moving_mat = compose_matrix44(
                [-moving_shift[i] for i in range(3)] + [0, 0, 0])

        self.metric.setup(static_centered, moving_centered)
        cost = self.metric.distance

        if self.method == 'Powell':
            if self.options is None:
                self.options = {'xtol': 1e-6, 'ftol': 1e-6, 'maxiter': 1e6}
            opt = Optimizer(cost, self.x0.tolist(),
                            method=self.method,
                            options=self.options,
                            evolution=self.evolution)
        elif self.method == 'L-BFGS-B':
            if self.options is None:
                self.options = {'maxcor': 10, 'ftol': 1e-7,
                                'gtol': 1e-5, 'eps': 1e-8,
                                'maxiter': 100}
            opt = Optimizer(cost, self.x0.tolist(),
                            method=self.method,
                            bounds=self.bounds,
                            options=self.options,
                            evolution=self.evolution)

        if self.verbose:
            opt.print_summary()

        # Chain the moving-side transform, the optimized transform and the
        # static-side transform into one matrix.
        mat = compose_transformations(moving_mat,
                                      compose_matrix44(opt.xopt),
                                      static_mat)

        # Same composition for every intermediate parameter vector, when the
        # optimizer recorded them.
        mat_history = []
        if opt.evolution is not None:
            mat_history = [
                compose_transformations(moving_mat,
                                        compose_matrix44(params),
                                        static_mat)
                for params in opt.evolution
            ]

        srm = StreamlineRegistrationMap(mat, opt.xopt, opt.fopt,
                                        mat_history, opt.nfev, opt.nit)
        del opt
        return srm
def show_results(streamlines, vol, affine, world_coords=True, opacity=0.6):
    """Show streamlines over a volume slicer in an interactive window.

    Parameters
    ----------
    streamlines : sequence of arrays or None
        Streamlines to draw as lines. When None, only the volume is shown.
    vol : ndarray
        Volume displayed by the slicer; its first three dimensions set the
        display extent and the slider range.
    affine : array (4, 4)
        Affine of ``vol``. With ``world_coords=True`` it is handed to the
        slicer; otherwise its inverse is applied to the streamlines and the
        slicer uses the identity.
    world_coords : bool, optional
        Whether to display in the space defined by ``affine`` (True) or to
        bring the streamlines onto the volume grid instead (False).
    opacity : float, optional
        Opacity of the slicer actor.

    Notes
    -----
    Starts the show manager's interactive session. The current render-window
    size is tracked in the module-level global ``size`` so that the slider
    can be re-placed when the window is resized.
    """
    from dipy.viz import actor, window, widget

    # Use the volume that is actually displayed; the previous code read an
    # unrelated name (`data`) that is not a parameter of this function.
    shape = vol.shape

    # `streamlines` may be None (see the checks below), so only transform
    # when there is something to transform.
    if streamlines is not None and not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

    ren = window.Renderer()
    if streamlines is not None:
        stream_actor = actor.line(streamlines)

    if not world_coords:
        image_actor = actor.slicer(vol, affine=np.eye(4))
    else:
        image_actor = actor.slicer(vol, affine)

    slicer_opacity = opacity
    image_actor.opacity(slicer_opacity)

    if streamlines is not None:
        ren.add(stream_actor)
    ren.add(image_actor)

    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    def change_slice(obj, event):
        # Show the single axial slice selected by the slider.
        z = int(np.round(obj.get_value()))
        image_actor.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    slider = widget.slider(show_m.iren,
                           show_m.ren,
                           callback=change_slice,
                           min_value=0,
                           max_value=shape[2] - 1,
                           value=shape[2] / 2,
                           label="Move slice",
                           right_normalized_pos=(.98, 0.6),
                           size=(120, 0),
                           label_format="%0.lf",
                           color=(1., 1., 1.),
                           selected_color=(0.86, 0.33, 1.))

    # Last known window size, shared with the resize callback below.
    global size
    size = ren.GetSize()

    def win_callback(obj, event):
        # Re-place the slider only when the window size actually changed.
        global size
        if size != obj.GetSize():

            slider.place(ren)
            size = obj.GetSize()

    show_m.initialize()

    show_m.add_window_callback(win_callback)
    show_m.render()
    show_m.start()

    del show_m
Пример #41
0
    else:
        return float(b) / a


# Script body: reads three positional command-line arguments, loads a
# reference image and a trackvis file, and collects the first and last point
# of every streamline. NOTE(review): this fragment ends inside the `if` body
# below; the code the trailing comments refer to is not part of it.
args = sys.argv[1:]
ref_img_in = args[0]  # path of the reference image
file_in = args[1]     # path of the trackvis file to read
file_out = args[2]    # output path (not used within this fragment)

ref_img = nib.load(ref_img_in)
#ref_img_shape = ref_img.get_data().shape
# Take the shape from the header instead of the image data.
ref_img_shape = ref_img.header.get_data_shape()

streams, hdr = trackvis.read(file_in)
# Keep only the first item of each record -- presumably the point array.
streamlines = [s[0] for s in streams]
# Apply the inverse of the reference affine to the streamlines (presumably
# taking them from world space to voxel space -- confirm).
streamlines = transform_streamlines(streamlines, np.linalg.inv(ref_img.affine))

# Empty volumes on the reference grid; they are not filled in this fragment.
mask_start = np.zeros(ref_img_shape)
mask_end = np.zeros(ref_img_shape)

if len(streamlines) > 0:

    # First and last point of every streamline.
    startpoints = []
    endpoints = []
    for streamline in streamlines:
        startpoints.append(streamline[0])
        endpoints.append(streamline[-1])

    # All start points followed by all end points, as one array.
    points = np.array(startpoints + endpoints)
    # Subsample points to make clustering faster
    #  Has to be at least 50k to work properly for very big tracts like CC (otherwise points too far apart for DBSCAN)
Пример #42
0
from src.tractography.viz import draw_bundles
from os import listdir  # , mkdir
from os.path import isfile  # , isdir
from src.tractography.io import read_ply
import argparse
from dipy.align.streamlinear import compose_matrix44
from dipy.tracking.streamline import transform_streamlines

# Script body: collect every '.ply' bundle in a directory, rotate each one
# with the same transform and draw them together.
import os

parser = argparse.ArgumentParser(description='Input argument parser.')
parser.add_argument('-f', type=str, help='location of files')
args = parser.parse_args()

# Honour -f when it is given (it used to be parsed and then ignored); fall
# back to the previously hard-coded sample directory otherwise.
data_path = args.f if args.f else '../data/132118/'

# os.path.join keeps the paths valid whether or not the directory was given
# with a trailing separator.
files = [
    os.path.join(data_path, f) for f in listdir(data_path)
    if isfile(os.path.join(data_path, f)) and f.endswith('.ply')
]

# Same transform for every bundle: zero translation, with 90 given for the
# last two of the six parameters.
mat = compose_matrix44([0, 0, 0, 0, 90, 90])
brain = [transform_streamlines(read_ply(name), mat) for name in files]
draw_bundles(brain, rotate=True)
"""
data1 = read_ply('../data/132118/m_ex_atr-left_shore.ply')
data2 = read_ply('../data/132118/m_ex_atr-right_shore.ply')
draw_bundles([data1,data2])
"""
Пример #43
0
def evaluate_along_streamlines(scalar_img,
                               streamlines,
                               beginnings,
                               nr_points,
                               dilate=0,
                               predicted_peaks=None,
                               affine=None):
    # Runtime:
    # - default:                2.7s (test),    56s (all),      10s (test 4 bundles, 100 points)
    # - map_coordinate order 1: 1.9s (test),    26s (all),       6s (test 4 bundles, 100 points)
    # - map_coordinate order 3: 2.2s (test),    33s (all),
    # - values_from_volume:     2.5s (test),    43s (all),
    # - AFQ:                      ?s (test),     ?s (all),      85s  (test 4 bundles, 100 points)
    # => AFQ a lot slower than others

    streamlines = list(
        transform_streamlines(streamlines, np.linalg.inv(affine)))

    for i in range(dilate):
        beginnings = binary_dilation(beginnings)
    beginnings = beginnings.astype(np.uint8)
    streamlines = _orient_to_same_start_region(streamlines, beginnings)
    if predicted_peaks is not None:
        # scalar img can also be orig peaks
        best_orig_peaks = fiber_utils.get_best_original_peaks(
            predicted_peaks, scalar_img, peak_len_thr=0.00001)
        scalar_img = np.linalg.norm(best_orig_peaks, axis=-1)

    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane | afq

    if algorithm == "equal_dist":
        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)
        ### Aggregation ###
        values_mean = np.array(values).mean(axis=1)
        values_std = np.array(values).std(axis=1)
        return values_mean, values_std

    if algorithm == "distance_map":  # cKDTree

        ### Sampling ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        values = map_coordinates(scalar_img, np.array(streamlines).T, order=1)

        ### Aggregating by cKDTree approach ###
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        if len(centroids) > 1:
            print("WARNING: number clusters > 1 ({})".format(len(centroids)))
        _, segment_idxs = cKDTree(centroids.data, 1,
                                  copy_data=True).query(streamlines,
                                                        k=1)  # (2000, 100)

        values_t = np.array(values).T  # (2000, 100)

        # If we want to take weighted mean like in AFQ:
        # weights = dsa.gaussian_weights(Streamlines(streamlines))
        # values_t = weights * values_t
        # return np.sum(values_t, 0), None

        results_dict = defaultdict(list)
        for idx, sl in enumerate(values_t):
            for jdx, seg in enumerate(sl):
                results_dict[segment_idxs[idx, jdx]].append(seg)

        if len(results_dict.keys()) < nr_points:
            print(
                "WARNING: found less than required points. Filling up with centroid values."
            )
            centroid_values = map_coordinates(scalar_img,
                                              np.array([centroids[0]]).T,
                                              order=1)
            for i in range(nr_points):
                if len(results_dict[i]) == 0:
                    results_dict[i].append(np.array(centroid_values).T[0, i])

        results_mean = []
        results_std = []
        for key in sorted(results_dict.keys()):
            value = results_dict[key]
            if len(value) > 0:
                results_mean.append(np.array(value).mean())
                results_std.append(np.array(value).std())
            else:
                print("WARNING: empty segment")
                results_mean.append(0)
                results_std.append(0)
        return results_mean, results_std

    elif algorithm == "cutting_plane":
        # This will resample all streamline to have equally distant points (resulting in a different number of points
        # in each streamline). Then the "middle" of the tract will be estimated taking the middle element of the
        # centroid (estimated with QuickBundles). Then each streamline the point closest to the "middle" will be
        # calculated and points will be indexed for each streamline starting from the middle. Then averaging across
        # all streamlines will be done by taking the mean for points with same indices.

        ### Sampling ###
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines, max_nr_points=nr_points)
        # map_coordinates does not allow streamlines with different lengths -> use values_from_volume
        values = np.array(
            values_from_volume(scalar_img, streamlines, affine=np.eye(4))).T

        ### Aggregating by Cutting Plane approach ###
        # Resample to all fibers having same number of points -> needed for QuickBundles
        streamlines_resamp = fiber_utils.resample_fibers(streamlines,
                                                         nb_points=nr_points)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroids = Streamlines(clusters.centroids)

        # index of the middle cluster
        middle_idx = int(nr_points / 2)
        middle_point = centroids[0][middle_idx]
        # For each streamline get idx for the point which is closest to the middle
        segment_idxs = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)

        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        base_idx = 1000  # use higher index to avoid negative numbers for area below middle
        for idx, sl in enumerate(streamlines):
            sl_middle_pos = segment_idxs[idx]
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            # indices for one streamline e.g. [998, 999, 1000, 1001, 1002, 1003]; 1000 is middle
            r = range((base_idx - before_elems), (base_idx + after_elems))
            segment_idxs_eqlen.append(r)
        segment_idxs = segment_idxs_eqlen

        # Calcuate maximum number of indices to not result in more indices than nr_points.
        # (this could be case if one streamline is very off-center and therefore has a lot of points only on one
        # side. In this case the values too far out of this streamline will be cut off).
        max_idx = base_idx + int(nr_points / 2)
        min_idx = base_idx - int(nr_points / 2)

        # Group by segment indices
        results_dict = defaultdict(list)
        for idx, sl in enumerate(values):
            for jdx, seg in enumerate(sl):
                current_idx = segment_idxs[idx][jdx]
                if current_idx >= min_idx and current_idx < max_idx:
                    results_dict[current_idx].append(seg)

        # If values missing fill up with centroid values
        if len(results_dict.keys()) < nr_points:
            print(
                "WARNING: found less than required points. Filling up with centroid values."
            )
            centroid_sl = [centroids[0]]
            centroid_sl = np.array(centroid_sl).T
            centroid_values = map_coordinates(scalar_img, centroid_sl, order=1)
            for idx, seg_idx in enumerate(range(min_idx, max_idx)):
                if len(results_dict[seg_idx]) == 0:
                    results_dict[seg_idx].append(
                        np.array(centroid_values).T[0, idx])

        # Aggregate by mean
        results_mean = []
        results_std = []
        for key in sorted(results_dict.keys()):
            value = results_dict[key]
            if len(value) > 0:
                results_mean.append(np.array(value).mean())
                results_std.append(np.array(value).std())
            else:
                print("WARNING: empty segment")
                results_mean.append(0)
                results_std.append(0)
        return results_mean, results_std

    elif algorithm == "afq":
        ### sampling + aggregation ###
        streamlines = fiber_utils.resample_fibers(streamlines,
                                                  nb_points=nr_points)
        streamlines = Streamlines(streamlines)
        weights = dsa.gaussian_weights(streamlines)
        results_mean = dsa.afq_profile(scalar_img,
                                       streamlines,
                                       affine=np.eye(4),
                                       weights=weights)
        results_std = np.zeros(nr_points)
        return results_mean, results_std
                    with open(picklepath_connectome, 'rb') as f:
                        M = pickle.load(f)
                if os.path.exists(grouping_xlsxpath):
                    grouping = extract_grouping(grouping_xlsxpath,
                                                index_to_struct,
                                                None,
                                                verbose=verbose)
                else:
                    if allow_preprun:
                        labelmask, labelaffine, labeloutpath, index_to_struct = getlabeltypemask(
                            label_folder,
                            'MDT',
                            ROI_legends,
                            labeltype=labeltype,
                            verbose=verbose)
                        streamlines_world = transform_streamlines(
                            trkdata.streamlines, np.linalg.inv(labelaffine))

                        #M, grouping = connectivity_matrix_func(trkdata.streamlines, function_processes, labelmask,
                        #                                       symmetric=True, mapping_as_streamlines=False,
                        #                                       affine_streams=trkdata.space_attributes[0],
                        #                                       inclusive=inclusive)
                        M, grouping = connectivity_matrix_func(
                            streamlines_world,
                            np.eye(4),
                            labelmask,
                            inclusive=inclusive,
                            symmetric=symmetric,
                            return_mapping=True,
                            mapping_as_streamlines=False,
                            reference_weighting=None,
                            volume_weighting=False,
Пример #45
0
def direct_streamline_norm(streams, fa_path, ap_path, dir_path, track_type,
                           target_samples, conn_model, network, node_size,
                           dens_thresh, ID, roi, min_span_tree, disp_filt,
                           parc, prune, atlas, labels_im_file, uatlas, labels,
                           coords, norm, binary, atlas_mni, basedir_path,
                           curv_thr_list, step_list, directget, min_length,
                           error_margin, t1_aligned_mni):
    """
    A Function to perform normalization of streamlines tracked in native
    diffusion space to an MNI-space template.

    Whether normalization actually runs is controlled by the
    ``tracking -> DSN`` entry of pynets' ``runconfig.yaml``. When it is
    disabled, the native-space inputs are passed through unchanged.

    Parameters
    ----------
    streams : str
        File path to the native-space streamline array sequence in .trk
        format (loaded here as the tractogram to be normalized).
    fa_path : str
        File path to FA Nifti1Image.
    ap_path : str
        File path to the anisotropic power Nifti1Image.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    labels_im_file : str
        File path to atlas parcellation Nifti1Image aligned to dwi space.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    basedir_path : str
        Path to directory to output direct-streamline normalized temp files
        and outputs.
    curv_thr_list : list
        List of integer curvature thresholds used to perform ensemble tracking.
    step_list : list
        List of float step-sizes used to perform ensemble tracking.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Value embedded in the output file names after ``_tol-`` and returned
        unchanged. NOTE(review): presumably the streamline-to-node tolerance
        used upstream -- confirm against the caller.
    t1_aligned_mni : str
        File path to the T1w Nifti1Image in template MNI space.

    Returns
    -------
    streams_mni : str
        File path to normalized streamline array sequence in .trk format
        (or `streams` unchanged when DSN is disabled).
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    atlas_mni : str
        File path to the '_liberal' union parcellation written under
        ``{dir_path}/parcellations`` (or `labels_im_file` when DSN is
        disabled).
    directget : str
        The statistical approach to tracking. Options are: det
        (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image (or `fa_path` when DSN
        is disabled).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        The `error_margin` argument, passed through unchanged.

    Raises
    ------
    ValueError
        If the tractogram contains no streamlines, or if fewer than 90% of
        the streamlines survive warping for every tested grid configuration.

    References
    ----------
    .. [1] Greene, C., Cieslak, M., & Grafton, S. T. (2017). Effect of
      different spatial normalization approaches on tractography and structural
      brain networks. Network Neuroscience, 1-19.
    """
    import sys
    import gc
    from dipy.tracking.streamline import transform_streamlines
    from pynets.registration import reg_utils as regutils
    import pkg_resources
    import yaml
    import os.path as op
    from pynets.registration.reg_utils import vdc
    from nilearn.image import resample_to_img
    from dipy.io.streamline import load_tractogram
    from dipy.tracking import utils
    from dipy.tracking._utils import _mapping_to_voxel
    from dipy.io.stateful_tractogram import Space, StatefulTractogram, Origin
    from dipy.io.streamline import save_tractogram

    # The open() call lives inside the try so that a missing config file is
    # actually reported by the handler below.
    try:
        with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
                  "r") as stream:
            # yaml.load() without an explicit Loader is unsafe and is
            # rejected by PyYAML >= 6. safe_load assumes runconfig.yaml holds
            # only plain YAML (no Python-specific tags).
            hardcoded_params = yaml.safe_load(stream)
    except (FileNotFoundError, yaml.YAMLError) as e:
        print(e, "Failed to parse runconfig.yaml")
        sys.exit(1)

    run_dsn = hardcoded_params['tracking']["DSN"][0]

    if run_dsn is True:
        # makedirs also creates the intermediate 'dmri_reg' directory and
        # tolerates the directory already existing.
        dsn_dir = f"{basedir_path}/dmri_reg/DSN"
        os.makedirs(dsn_dir, exist_ok=True)

        namer_dir = f"{dir_path}/tractography"
        os.makedirs(namer_dir, exist_ok=True)

        atlas_img = nib.load(labels_im_file)

        # Run SyN and normalize streamlines
        fa_img = nib.load(fa_path)
        vox_size = fa_img.header.get_zooms()[0]
        template_path = pkg_resources.resource_filename(
            "pynets", f"templates/FA_{int(vox_size)}mm.nii.gz")

        # The template is loaded only to verify that it is readable;
        # template_img itself is not used afterwards. The exception class
        # that is caught differs between Windows and other platforms.
        if not sys.platform.startswith('win'):
            try:
                template_img = nib.load(template_path)
            except indexed_gzip.ZranError as e:
                print(
                    e, "\nCannot load FA template. Do you have git-lfs "
                    "installed?")
                sys.exit(1)
        else:
            try:
                template_img = nib.load(template_path)
            except ImportError as e:
                print(
                    e, "\nCannot load FA template. Do you have git-lfs "
                    "installed?")
                sys.exit(1)

        uatlas_mni_img = nib.load(atlas_mni)
        t1_aligned_mni_img = nib.load(t1_aligned_mni)
        brain_mask = np.asarray(t1_aligned_mni_img.dataobj).astype("bool")

        # Both output files share the same descriptive name suffix; build it
        # once instead of duplicating the whole format expression.
        network_tag = network + "_" if network is not None else ""
        roi_tag = (op.basename(roi).split(".")[0] + "_"
                   if roi is not None else "")
        if (node_size != "parc") and (node_size is not None):
            node_tag = "_" + str(node_size) + "mm_"
        else:
            node_tag = "_"
        curv_tag = str(curv_thr_list).replace(", ", "_")
        step_tag = str(step_list).replace(", ", "_")
        name_suffix = (
            f"{network_tag}{roi_tag}{conn_model!s}_{target_samples!s}"
            f"{node_tag}curv{curv_tag}step{step_tag}"
            f"tracktype-{track_type!s}_directget-{directget!s}"
            f"_minlength-{min_length!s}_tol-{error_margin!s}"
        )
        streams_mni = f"{namer_dir}/streamlines_mni_{name_suffix}.trk"
        density_mni = f"{namer_dir}/density_map_mni_{name_suffix}.nii.gz"

        # SyN FA->Template
        [mapping, affine_map,
         warped_fa] = regutils.wm_syn(template_path, fa_path, t1_aligned_mni,
                                      ap_path, dsn_dir)

        tractogram = load_tractogram(
            streams,
            fa_img,
            to_origin=Origin.NIFTI,
            to_space=Space.VOXMM,
            bbox_valid_check=False,
        )

        fa_img.uncache()
        streamlines = tractogram.streamlines
        warped_fa_img = nib.load(warped_fa)
        warped_fa_affine = warped_fa_img.affine
        warped_fa_shape = warped_fa_img.shape

        streams_in_curr_grid = transform_streamlines(streamlines,
                                                     warped_fa_affine)

        # The retention ratio below divides by the streamline count, so an
        # empty tractogram must be rejected explicitly.
        if len(streams_in_curr_grid) == 0:
            raise ValueError(
                f"No streamlines found in {streams}; cannot perform DSN.")

        # Create isocenter mapping where we anchor the origin transformation
        # affine to the corner of the FOV by scaling x, y, z offsets according
        # to a multiplicative van der Corput sequence with a base value equal
        # to the voxel resolution
        [x_mul, y_mul, z_mul] = [vdc(i, vox_size) for i in range(1, 4)]

        ref_grid_aff = vox_size * np.eye(4)
        ref_grid_aff[3][3] = 1

        streams_final_filt = []
        i = 0
        # Test for various types of voxel-grid configurations
        combs = [(-x_mul, -y_mul, -z_mul), (-x_mul, -y_mul, z_mul),
                 (-x_mul, y_mul, -z_mul), (x_mul, -y_mul, -z_mul),
                 (x_mul, y_mul, z_mul)]
        # Retry with the next sign configuration until at least 90% of the
        # streamlines survive the warp.
        while len(streams_final_filt) / len(streams_in_curr_grid) < 0.90:
            print(f"Warping streamlines to MNI space. Attempt {i}...")
            print(len(streams_final_filt) / len(streams_in_curr_grid))
            if i > len(combs) - 1:
                raise ValueError('DSN failed. Header orientation '
                                 'information may be corrupted. '
                                 'Is your dataset oblique?')

            # Scale the translation column of a fresh copy of the affine by
            # the current configuration's multipliers.
            adjusted_affine = affine_map.affine.copy()
            adjusted_affine[0][3] = adjusted_affine[0][3] * combs[i][0]
            adjusted_affine[1][3] = adjusted_affine[1][3] * combs[i][1]
            adjusted_affine[2][3] = adjusted_affine[2][3] * combs[i][2]

            streams_final_filt = regutils.warp_streamlines(
                adjusted_affine, ref_grid_aff, mapping, warped_fa_img,
                streams_in_curr_grid, brain_mask)

            i += 1

        # Remove streamlines with negative voxel indices
        lin_T, offset = _mapping_to_voxel(np.eye(4))
        streams_final_filt_final = []
        for sl in streams_final_filt:
            inds = np.dot(sl, lin_T)
            inds += offset
            if not inds.min().round(decimals=6) < 0:
                streams_final_filt_final.append(sl)

        # Save streamlines
        stf = StatefulTractogram(
            streams_final_filt_final,
            reference=uatlas_mni_img,
            space=Space.VOXMM,
            origin=Origin.NIFTI,
        )
        stf.remove_invalid_streamlines()
        streams_final_filt_final = stf.streamlines
        save_tractogram(stf, streams_mni, bbox_valid_check=True)
        warped_fa_img.uncache()

        # Create and save MNI density map
        nib.save(
            nib.Nifti1Image(
                utils.density_map(streams_final_filt_final,
                                  affine=np.eye(4),
                                  vol_dims=warped_fa_shape),
                warped_fa_affine,
            ),
            density_mni,
        )

        # Map parcellation from native space back to MNI-space and create an
        # 'uncertainty-union' parcellation with original mni-space uatlas
        warped_uatlas = affine_map.transform_inverse(
            mapping.transform(
                np.asarray(atlas_img.dataobj).astype("int"),
                interpolation="nearestneighbour",
            ),
            interp="nearest",
        )
        atlas_img.uncache()
        warped_uatlas_img_res_data = np.asarray(
            resample_to_img(
                nib.Nifti1Image(warped_uatlas, affine=warped_fa_affine),
                uatlas_mni_img,
                interpolation="nearest",
                clip=False,
            ).dataobj)
        uatlas_mni_data = np.asarray(uatlas_mni_img.dataobj)
        uatlas_mni_img.uncache()
        # Despite its name, overlap_mask is True wherever the two
        # parcellations do NOT both carry a label (inverse of intersection).
        overlap_mask = np.invert(
            warped_uatlas_img_res_data.astype("bool") *
            uatlas_mni_data.astype("bool"))
        os.makedirs(f"{dir_path}/parcellations", exist_ok=True)
        atlas_mni = f"{dir_path}/parcellations/" \
                    f"{op.basename(uatlas).split('.nii')[0]}_liberal.nii.gz"

        # Outside the intersection the two label maps are summed; inside it
        # the warped native-space label is used.
        nib.save(
            nib.Nifti1Image(
                warped_uatlas_img_res_data * overlap_mask.astype("int") +
                uatlas_mni_data * overlap_mask.astype("int") +
                np.invert(overlap_mask).astype("int") *
                warped_uatlas_img_res_data.astype("int"),
                affine=uatlas_mni_img.affine,
            ),
            atlas_mni,
        )

        del (
            tractogram,
            streamlines,
            warped_uatlas_img_res_data,
            uatlas_mni_data,
            overlap_mask,
            stf,
            streams_final_filt_final,
            streams_final_filt,
            streams_in_curr_grid,
            brain_mask,
        )

        gc.collect()

        assert len(coords) == len(labels)

    else:
        print(
            "Skipping Direct Streamline Normalization (DSN). Will proceed to "
            "define fiber connectivity in native diffusion space...")
        streams_mni = streams
        warped_fa = fa_path
        atlas_mni = labels_im_file

    return (streams_mni, dir_path, track_type, target_samples, conn_model,
            network, node_size, dens_thresh, ID, roi, min_span_tree, disp_filt,
            parc, prune, atlas, uatlas, labels, coords, norm, binary,
            atlas_mni, directget, warped_fa, min_length, error_margin)
Пример #46
0
def save_vtk_streamlines(streamlines, filename, to_lps=True, binary=False):
    """Write a set of streamlines to disk as vtk polydata.

    File formats can be VTK, FIB

    Parameters
    ----------
    streamlines : list
        list of 2D arrays or ArraySequence
    filename : string
        output filename (.vtk or .fib)
    to_lps : bool
        Default to True, will follow the vtk file convention for streamlines
        Will be supported by MITKDiffusion and MI-Brain
    binary : bool
        save the file as binary
    """
    if to_lps:
        # ras (mm) to lps (mm): negate the x and y axes
        ras_to_lps = np.diag([-1., -1., 1., 1.])
        streamlines = transform_streamlines(streamlines, ras_to_lps)

    # Stack every streamline's points into one (N, 3) array
    n_streamlines = len(streamlines)
    all_points = np.vstack(streamlines)

    # Flat vtk cell layout: for each line, its point count followed by the
    # indices of its points within all_points
    connectivity = []
    offset = 0
    for line in streamlines:
        n_pts = len(line)
        connectivity.append(n_pts)
        connectivity.extend(range(offset, offset + n_pts))
        offset += n_pts

    # Points in vtk array format
    points_vtk = vtk.vtkPoints()
    points_f32 = all_points.astype(np.float32)
    points_vtk.SetData(ns.numpy_to_vtk(points_f32, deep=True))

    # Lines in vtk array format
    cells_vtk = vtk.vtkCellArray()
    cells_vtk.SetNumberOfCells(n_streamlines)
    connectivity_vtk = ns.numpy_to_vtk(np.array(connectivity), deep=True)
    cells_vtk.GetData().DeepCopy(connectivity_vtk)

    # Assemble the poly_data
    poly = vtk.vtkPolyData()
    poly.SetPoints(points_vtk)
    poly.SetLines(cells_vtk)

    # Configure the writer and flush to disk
    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(filename)
    writer = utils.set_input(writer, poly)

    if binary:
        writer.SetFileTypeToBinary()

    writer.Update()
    writer.Write()
Пример #47
0
def bundle_analysis(model_bundle_folder, bundle_folder, orig_bundle_folder,
                    metric_folder, group, subject, no_disks=100,
                    out_dir=''):
    """
    Applies statistical analysis on bundles and saves the results
    in a directory specified by ``out_dir``.

    The three bundle folders are listed and sorted, and files are matched
    across folders by their position in the sorted listings. For every
    bundle, each file in ``metric_folder`` is processed by ``dti_measures``
    (file names of the form ``XX.nii.gz``) or by ``peak_values`` (any other
    file); ``out_dir`` is forwarded to those helpers.

    Parameters
    ----------
    model_bundle_folder : string
        Path to the input model bundle files. This path may contain
        wildcards to process multiple inputs at once.
    bundle_folder : string
        Path to the input bundle files in common space. This path may
        contain wildcards to process multiple inputs at once.
    orig_bundle_folder : string
        Path to the input bundle files in native space. This path may
        contain wildcards to process multiple inputs at once.
    metric_folder : string
        Path to the input dti metric or/and peak files. It will be used as
        metric for statistical analysis of bundles. Must contain
        ``fa.nii.gz``, whose affine is used to transform the native-space
        bundles.
    group : string
        what group subject belongs to e.g. control or patient
    subject : string
        subject id e.g. 10001
    no_disks : integer, optional
        Number of disks used for dividing bundle into disks. (Default 100)
    out_dir : string, optional
        Output directory (default input file directory)

    References
    ----------
    .. [Chandio19] Chandio, B.Q., S. Koudoro, D. Reagan, J. Harezlak,
    E. Garyfallidis, Bundle Analytics: a computational and statistical
    analyses framework for tractometric studies, Proceedings of:
    International Society of Magnetic Resonance in Medicine (ISMRM),
    Montreal, Canada, 2019.

    """
    mb = sorted(os.listdir(model_bundle_folder))
    bd = sorted(os.listdir(bundle_folder))
    org_bd = sorted(os.listdir(orig_bundle_folder))

    # The native-space listing drives the iteration; the other two listings
    # are indexed by the same position.
    for io, orig_name in enumerate(org_bd):
        mbundles, _ = load_trk(os.path.join(model_bundle_folder, mb[io]))
        bundles, _ = load_trk(os.path.join(bundle_folder, bd[io]))
        orig_bundles, _ = load_trk(os.path.join(orig_bundle_folder,
                                                orig_name))

        mbundle_streamlines = set_number_of_points(mbundles,
                                                   nb_points=no_disks)

        # Cluster the resampled model bundle to obtain its centroid(s)
        qb_metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=25., metric=qb_metric)
        clusters = qb.cluster(mbundle_streamlines)
        centroids = Streamlines(clusters.centroids)

        print('Number of centroids ', len(centroids.data))
        print('Model bundle ', mb[io])
        print('Number of streamlines in bundle in common space ',
              len(bundles))
        print('Number of streamlines in bundle in original space ',
              len(orig_bundles))

        # For every point of the common-space bundle, index of the nearest
        # centroid point
        _, indx = cKDTree(centroids.data, 1,
                          copy_data=True).query(bundles.data, k=1)

        metric_files_names = os.listdir(metric_folder)
        _, affine = load_nifti(os.path.join(metric_folder, "fa.nii.gz"))

        # Apply the inverse of the FA image affine to the native-space bundle
        affine_r = np.linalg.inv(affine)
        transformed_orig_bundles = transform_streamlines(orig_bundles,
                                                         affine_r)

        # Model bundle file name without its 4-character extension
        bm = mb[io][:-4]

        for metric_file in metric_files_names:
            # Fresh index array and result dict for every metric file
            ind = np.array(indx)
            dt = dict()
            metric_name = os.path.join(metric_folder, metric_file)

            if metric_file[2:] == '.nii.gz':
                # Two-letter metric name, e.g. the 'fa' in 'fa.nii.gz'
                fm = metric_file[:2]
                metric_data, _ = load_nifti(metric_name)

                dti_measures(transformed_orig_bundles, metric_data, dt, fm,
                             bm, subject, group, ind, out_dir)

            else:
                # Anything else is loaded as a peaks file, labelled by its
                # first three characters
                fm = metric_file[:3]
                metric_data = load_peaks(metric_name)
                peak_values(bundles, metric_data, dt, fm, bm, subject, group,
                            ind, out_dir)
Пример #48
0
def fiber_simple_3d_show_advanced(img,
                                  streamlines,
                                  colors=None,
                                  linewidth=1,
                                  s='png',
                                  imgcolor=False,
                                  world_coords=True,
                                  interactive=True,
                                  out_path=None):
    """Show streamlines with three orthogonal image slices and slider widgets.

    A line actor is built from ``streamlines`` and three slicer actors
    (one per axis) are built from the image volume. A 2D panel with four
    sliders lets the user move each slice and change the slice opacity.

    Parameters
    ----------
    img : image object
        Must expose ``dataobj``, ``shape`` and ``affine`` (e.g. a nibabel
        image).
    streamlines : sequence
        Streamlines forwarded to ``actor.line``.
    colors : optional
        Forwarded unchanged to ``actor.line``.
    linewidth : number, optional
        Forwarded unchanged to ``actor.line``.
    s : str, optional
        Suffix substituted into the default screenshot file name. Only used
        when ``interactive`` is false and ``out_path`` is None.
    imgcolor : bool, optional
        If true, the slicers use a lookup table created with
        ``actor.colormap_lookup_table``; otherwise no lookup table is passed.
    world_coords : bool, optional
        If true (default) the slicers are placed with ``img.affine``. If
        false, the streamlines are transformed with the inverse affine and
        the slicers use an identity affine.
    interactive : bool, optional
        If true (default) an interactive window is started. If false a
        screenshot is written with ``window.record`` instead.
    out_path : str, optional
        Screenshot destination for the non-interactive mode. When None the
        historical hard-coded location (parameterised by ``s``) is used.

    Returns
    -------
    None
    """
    # ``img.get_data()`` is deprecated in nibabel; this is the documented
    # drop-in replacement that keeps the on-disk dtype.
    data = np.asanyarray(img.dataobj)
    shape = img.shape
    affine = img.affine

    # Objects given in world coordinates are moved back to native (voxel)
    # space with the inverse affine when native display is requested.
    if not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

    ren = window.Renderer()
    stream_actor = actor.line(streamlines, colors=colors, linewidth=linewidth)

    # Optional colormap for the image slices.
    if imgcolor:
        lut = actor.colormap_lookup_table(scale_range=(0, 1),
                                          hue_range=(0, 1.),
                                          saturation_range=(0., 1.),
                                          value_range=(0., 1.))
    else:
        lut = None

    if not world_coords:
        image_actor_z = actor.slicer(data,
                                     affine=np.eye(4),
                                     lookup_colormap=lut)
    else:
        image_actor_z = actor.slicer(data, affine, lookup_colormap=lut)

    slicer_opacity = 0.6
    image_actor_z.opacity(slicer_opacity)

    # The X and Y slicers are copies of the Z slicer with a different
    # display extent, each pinned to the middle of its axis.
    image_actor_x = image_actor_z.copy()
    image_actor_x.opacity(slicer_opacity)
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    image_actor_y.opacity(slicer_opacity)
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    ren.add(stream_actor)
    ren.add(image_actor_z)
    ren.add(image_actor_x)
    ren.add(image_actor_y)

    # Sliders are driven through a ShowManager rather than a plain show call.
    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    line_slider_z = ui.LineSlider2D(min_value=0,
                                    max_value=shape[2] - 1,
                                    initial_value=shape[2] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_x = ui.LineSlider2D(min_value=0,
                                    max_value=shape[0] - 1,
                                    initial_value=shape[0] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    line_slider_y = ui.LineSlider2D(min_value=0,
                                    max_value=shape[1] - 1,
                                    initial_value=shape[1] / 2,
                                    text_template="{value:.0f}",
                                    length=140)

    opacity_slider = ui.LineSlider2D(min_value=0.0,
                                     max_value=1.0,
                                     initial_value=slicer_opacity,
                                     length=140)

    # Slider callbacks: each one re-positions a single slicer, or updates the
    # opacity of all three.
    def change_slice_z(i_ren, obj, slider):
        z = int(np.round(slider.value))
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    def change_slice_x(i_ren, obj, slider):
        x = int(np.round(slider.value))
        image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    def change_slice_y(i_ren, obj, slider):
        y = int(np.round(slider.value))
        image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1)

    def change_opacity(i_ren, obj, slider):
        new_opacity = slider.value
        image_actor_z.opacity(new_opacity)
        image_actor_x.opacity(new_opacity)
        image_actor_y.opacity(new_opacity)

    line_slider_z.add_callback(line_slider_z.slider_disk, "MouseMoveEvent",
                               change_slice_z)
    line_slider_x.add_callback(line_slider_x.slider_disk, "MouseMoveEvent",
                               change_slice_x)
    line_slider_y.add_callback(line_slider_y.slider_disk, "MouseMoveEvent",
                               change_slice_y)
    opacity_slider.add_callback(opacity_slider.slider_disk, "MouseMoveEvent",
                                change_opacity)

    def build_label(text):
        """Create a plain white 18pt Arial text block showing ``text``."""
        label = ui.TextBlock2D()
        label.message = text
        label.font_size = 18
        label.font_family = 'Arial'
        label.justification = 'left'
        label.bold = False
        label.italic = False
        label.shadow = False
        label.color = (1, 1, 1)
        return label

    line_slider_label_z = build_label(text="Z Slice")
    line_slider_label_x = build_label(text="X Slice")
    line_slider_label_y = build_label(text="Y Slice")
    opacity_slider_label = build_label(text="Opacity")

    # Panel holding the sliders and their labels.
    panel = ui.Panel2D(center=(1030, 120),
                       size=(300, 200),
                       color=(1, 1, 1),
                       opacity=0.1,
                       align="right")

    panel.add_element(line_slider_label_x, 'relative', (0.1, 0.75))
    panel.add_element(line_slider_x, 'relative', (0.65, 0.8))
    panel.add_element(line_slider_label_y, 'relative', (0.1, 0.55))
    panel.add_element(line_slider_y, 'relative', (0.65, 0.6))
    panel.add_element(line_slider_label_z, 'relative', (0.1, 0.35))
    panel.add_element(line_slider_z, 'relative', (0.65, 0.4))
    panel.add_element(opacity_slider_label, 'relative', (0.1, 0.15))
    panel.add_element(opacity_slider, 'relative', (0.65, 0.2))

    show_m.ren.add(panel)

    # The panel does not follow window resizes on its own, so the last known
    # window size is remembered and the panel is re-aligned when the width
    # changes. The size lives in a one-element list owned by this call
    # instead of a module-level global, so separate viewers cannot clobber
    # each other's state.
    last_size = [ren.GetSize()]

    def win_callback(obj, event):
        current = obj.GetSize()
        if last_size[0] != current:
            size_change = [current[0] - last_size[0][0], 0]
            last_size[0] = current
            panel.re_align(size_change)

    # NOTE(review): second initialize() kept from the original; presumably
    # needed after the panel was added - confirm before removing.
    show_m.initialize()

    ren.zoom(1.5)
    ren.reset_clipping_range()

    if interactive:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()
    else:
        if out_path is None:
            # Historical default location, kept for backward compatibility.
            out_path = (
                '/home/brain/workingdir/data/dwi/hcp/preprocessed/response_dhollander/'
                '100408/result/result20vs45/cc_clustering_png1/100408lr15_%s.png' %
                s)
        window.record(ren,
                      out_path=out_path,
                      size=(1200, 900),
                      reset_camera=False)

    del show_m
    """
Пример #49
0
def buan_bundle_profiles(model_bundle_folder,
                         bundle_folder,
                         orig_bundle_folder,
                         metric_folder,
                         group_id,
                         subject,
                         no_disks=100,
                         out_dir=''):
    """
    Applies statistical analysis on bundles and saves the results
    in a directory specified by ``out_dir``.

    The ``*.trk`` files of the three bundle folders are sorted by name and
    paired by position; the number of processed bundles is the number of
    model bundles. Bundles whose native-space tractogram has 5 streamlines
    or fewer are skipped.

    Parameters
    ----------
    model_bundle_folder : string
        Path to the folder containing the model bundle ``*.trk`` files.
    bundle_folder : string
        Path to the folder containing the bundle ``*.trk`` files in
        common space.
    orig_bundle_folder : string
        Path to the folder containing the bundle ``*.trk`` files in
        native space.
    metric_folder : string
        Path to the folder with the input dti metric (``*.nii.gz``) or/and
        peak (``*.pam5``) files. It will be used as metric for statistical
        analysis of bundles. The affine is read from the first ``*.nii.gz``
        file found.
    group_id : integer
        what group subject belongs to either 0 for control or 1 for patient
    subject : string
        subject id e.g. 10001
    no_disks : integer, optional
        Number of disks used for dividing bundle into disks. (Default 100)
    out_dir : string, optional
        Output directory (default input file directory)

    Raises
    ------
    IndexError
        If ``metric_folder`` holds no ``*.nii.gz`` file for a processed
        bundle, or if a bundle folder holds fewer ``*.trk`` files than the
        model bundle folder.

    References
    ----------
    .. [Chandio2020] Chandio, B.Q., Risacher, S.L., Pestilli, F., Bullock, D.,
    Yeh, FC., Koudoro, S., Rokem, A., Harezlack, J., and Garyfallidis, E.
    Bundle analytics, a computational framework for investigating the
    shapes and profiles of brain pathways across populations.
    Sci Rep 10, 17149 (2020)

    """

    t = time()

    mb = glob(os.path.join(model_bundle_folder, "*.trk"))
    print(mb)

    mb.sort()

    bd = glob(os.path.join(bundle_folder, "*.trk"))

    bd.sort()
    print(bd)
    org_bd = glob(os.path.join(orig_bundle_folder, "*.trk"))
    org_bd.sort()
    print(org_bd)

    # The model bundles drive the iteration; the other two lists are indexed
    # in lock-step with them.
    for io in range(len(mb)):

        mbundles = load_tractogram(mb[io],
                                   reference='same',
                                   bbox_valid_check=False).streamlines
        bundles = load_tractogram(bd[io],
                                  reference='same',
                                  bbox_valid_check=False).streamlines
        orig_bundles = load_tractogram(org_bd[io],
                                       reference='same',
                                       bbox_valid_check=False).streamlines

        if len(orig_bundles) <= 5:
            # Too few streamlines to analyse this bundle.
            continue

        ind = np.array(assignment_map(bundles, mbundles, no_disks))

        metric_files_names_dti = glob(
            os.path.join(metric_folder, "*.nii.gz"))

        metric_files_names_csa = glob(os.path.join(metric_folder,
                                                   "*.pam5"))

        _, affine = load_nifti(metric_files_names_dti[0])

        # Bring the native-space streamlines into the metric's voxel grid.
        affine_r = np.linalg.inv(affine)
        transformed_orig_bundles = transform_streamlines(
            orig_bundles, affine_r)

        # Bundle name: model bundle file name without the ".trk" suffix.
        bm = os.path.split(mb[io])[1][:-4]

        for metric_path in metric_files_names_dti:
            # Metric name: file name without the ".nii.gz" suffix.
            fm = os.path.split(metric_path)[1][:-7]

            logging.info("bm = " + bm)

            dt = dict()

            logging.info("metric = " + metric_path)

            metric, _ = load_nifti(metric_path)

            anatomical_measures(transformed_orig_bundles, metric, dt, fm,
                                bm, subject, group_id, ind, out_dir)

        for metric_path in metric_files_names_csa:
            # Metric name: file name without the ".pam5" suffix.
            fm = os.path.split(metric_path)[1][:-5]

            logging.info("bm = " + bm)
            logging.info("metric = " + metric_path)
            dt = dict()
            metric = load_peaks(metric_path)

            peak_values(transformed_orig_bundles, metric, dt, fm, bm,
                        subject, group_id, ind, out_dir)

    print("total time taken in minutes = ", (-t + time()) / 60)
Пример #50
0
    def _register_model_to_neighb(self,
                                  slr_num_thread=1,
                                  select_model=1000,
                                  select_target=1000,
                                  slr_transform_type='scaling'):
        """
        Progressively register the neighborhood centroids onto the model
        centroids and move ``self.model_centroids`` with the inverse of the
        resulting matrix.

        The registration is run in stages of increasing freedom
        (translation -> rigid -> similarity -> scaling), each stage being
        initialised with the optimum of the previous one, and stops at the
        stage named by ``slr_transform_type``.

        Parameters
        ----------
        slr_num_thread : int
            Number of threads for SLR.
            Should remain 1 for nearly all use-case.
        select_model : int
            Maximum number of clusters to select from the model.
        select_target : int
            Maximum number of clusters to select from the neighborhood.
        slr_transform_type : str
            Define the transformation for the local SLR.
            [translation, rigid, similarity, scaling].

        Returns
        -------
        None
            ``self.model_centroids`` is replaced by its transformed version.

        Raises
        ------
        KeyError
            If ``slr_transform_type`` is not one of the supported names.
        """
        possible_slr_transform_type = {
            'translation': 0,
            'rigid': 1,
            'similarity': 2,
            'scaling': 3
        }
        static = select_random_set_of_streamlines(self.model_centroids,
                                                  select_model, self.rng)
        moving = select_random_set_of_streamlines(self.neighb_centroids,
                                                  select_target, self.rng)

        # Tuple 0,1,2 are the min & max bound in x,y,z for translation
        # Tuple 3,4,5 are the min & max bound in x,y,z for rotation
        # Tuple 6,7,8 are the min & max bound in x,y,z for scaling
        # For uniform scaling (similarity), tuple #6 is enough
        bounds_dof = [(-20, 20), (-20, 20), (-20, 20), (-10, 10), (-10, 10),
                      (-10, 10), (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)]
        metric = BundleMinDistanceMetric(num_threads=slr_num_thread)
        slr_transform_type_id = possible_slr_transform_type[slr_transform_type]

        # Stage 0 - translation (3 parameters). Every transform type starts
        # here, so this stage is unconditional.
        init_transfo_dof = np.zeros(3)
        slr = StreamlineLinearRegistration(metric=metric,
                                           method="Powell",
                                           x0=init_transfo_dof,
                                           bounds=bounds_dof[:3],
                                           num_threads=slr_num_thread)
        slm = slr.optimize(static, moving)

        # Stage 1 - rigid (6 parameters), seeded with the translation.
        if slr_transform_type_id >= 1:
            init_transfo_dof = np.zeros(6)
            init_transfo_dof[:3] = slm.xopt

            slr = StreamlineLinearRegistration(metric=metric,
                                               x0=init_transfo_dof,
                                               bounds=bounds_dof[:6],
                                               num_threads=slr_num_thread)
            slm = slr.optimize(static, moving)

        # Stage 2 - similarity (7 parameters), seeded with the rigid result
        # and a unit isotropic scale. It also runs for 'scaling': the next
        # stage reads the isotropic scale from slm.xopt[6], which only exists
        # once this stage has produced a 7-parameter optimum.
        if slr_transform_type_id >= 2:
            init_transfo_dof = np.zeros(7)
            init_transfo_dof[:6] = slm.xopt
            init_transfo_dof[6] = 1.

            slr = StreamlineLinearRegistration(metric=metric,
                                               x0=init_transfo_dof,
                                               bounds=bounds_dof[:7],
                                               num_threads=slr_num_thread)
            slm = slr.optimize(static, moving)

        # Stage 3 - anisotropic scaling (9 parameters); the isotropic scale
        # found above initialises all three axes.
        if slr_transform_type_id >= 3:
            init_transfo_dof = np.zeros(9)
            init_transfo_dof[:6] = slm.xopt[:6]
            init_transfo_dof[6:] = np.array((slm.xopt[6], ) * 3)

            slr = StreamlineLinearRegistration(metric=metric,
                                               x0=init_transfo_dof,
                                               bounds=bounds_dof[:9],
                                               num_threads=slr_num_thread)
            slm = slr.optimize(static, moving)

        self.model_centroids = transform_streamlines(self.model_centroids,
                                                     np.linalg.inv(slm.matrix))
def show_results(streamlines, vol, affine, world_coords=True, opacity=0.6):
    """Display a volume slice, optionally with streamlines, in a window.

    A slider widget moves the displayed slice along the third axis of
    ``vol``. The call blocks until the interactive window is closed.

    Parameters
    ----------
    streamlines : sequence or None
        Streamlines to draw with ``actor.line``. If None, only the volume
        is shown.
    vol : ndarray
        Volume passed to ``actor.slicer``; its ``shape`` bounds the slider.
    affine : ndarray
        Affine of ``vol``. Its inverse is applied to the streamlines when
        ``world_coords`` is false.
    world_coords : bool, optional
        If true (default) the slicer is placed with ``affine``; otherwise
        an identity affine is used and the streamlines are transformed.
    opacity : float, optional
        Opacity applied to the slicer actor.

    Returns
    -------
    None
    """
    from dipy.viz import actor, window, widget
    import numpy as np

    shape = vol.shape
    has_streamlines = streamlines is not None

    # Only transform when there is something to transform: a None input is
    # explicitly supported further down, so it must not reach
    # transform_streamlines.
    if has_streamlines and not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines, np.linalg.inv(affine))

    ren = window.Renderer()

    if world_coords:
        image_actor = actor.slicer(vol, affine)
    else:
        image_actor = actor.slicer(vol, affine=np.eye(4))
    image_actor.opacity(opacity)

    if has_streamlines:
        ren.add(actor.line(streamlines))
    ren.add(image_actor)

    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    def change_slice(obj, event):
        z = int(np.round(obj.get_value()))
        image_actor.display_extent(0, shape[0] - 1,
                                   0, shape[1] - 1, z, z)

    slider = widget.slider(show_m.iren, show_m.ren,
                           callback=change_slice,
                           min_value=0,
                           max_value=shape[2] - 1,
                           value=shape[2] / 2,
                           label="Move slice",
                           right_normalized_pos=(.98, 0.6),
                           size=(120, 0), label_format="%0.lf",
                           color=(1., 1., 1.),
                           selected_color=(0.86, 0.33, 1.))

    # Last known window size, kept in a one-element list owned by this call
    # (not a module-level global) so concurrent/other viewers are unaffected.
    last_size = [ren.GetSize()]

    def win_callback(obj, event):
        # Re-place the slider whenever the window has been resized.
        if last_size[0] != obj.GetSize():
            slider.place(ren)
            last_size[0] = obj.GetSize()

    # NOTE(review): second initialize() kept from the original; presumably
    # required after the widget was created - confirm before removing.
    show_m.initialize()

    show_m.add_window_callback(win_callback)
    show_m.render()
    show_m.start()

    del show_m
Пример #52
0
def fiber_simple_3d_show(img, streamlines, world_coords=True,
                         slicer_opacity=0.6, out_path=None):
    """Show streamlines with three orthogonal slices, then save a screenshot.

    Three slicer actors (one per axis) are shown together with the
    streamlines; each slicer is driven by its own slider widget. After the
    interactive window is closed, the scene is zoomed and recorded to
    ``out_path``.

    Parameters
    ----------
    img : image object
        Must expose ``dataobj``, ``shape`` and ``affine`` (e.g. a nibabel
        image).
    streamlines : sequence
        Streamlines forwarded to ``actor.line``.
    world_coords : bool, optional
        If true (default) the slicers are placed with ``img.affine``. If
        false, the streamlines are transformed with the inverse affine and
        the slicers use an identity affine.
    slicer_opacity : float, optional
        Opacity applied to the three slicer actors.
    out_path : str, optional
        Screenshot destination. When None the historical hard-coded
        location is used.

    Returns
    -------
    None
    """
    if out_path is None:
        # Historical default location, kept for backward compatibility.
        out_path = ('/home/brain/workingdir/pyfat/pyfat/example/'
                    'test_results/cc_clusters_test.png')

    if not world_coords:
        from dipy.tracking.streamline import transform_streamlines
        streamlines = transform_streamlines(streamlines,
                                            np.linalg.inv(img.affine))

    # Renderer
    ren = window.Renderer()
    stream_actor = actor.line(streamlines)

    # ``img.get_data()`` is deprecated in nibabel; this is the documented
    # drop-in replacement that keeps the on-disk dtype.
    data = np.asanyarray(img.dataobj)
    if not world_coords:
        image_actor = actor.slicer(data, affine=np.eye(4))
    else:
        image_actor = actor.slicer(data, img.affine)

    image_actor.opacity(slicer_opacity)

    # Two extra slicers, copied from the first one and pinned to the middle
    # of the second and first axis respectively. Floor division keeps the
    # slice index an integer under Python 3.
    image_actor2 = image_actor.copy()
    image_actor2.opacity(slicer_opacity)
    image_actor2.display(None, image_actor2.shape[1] // 2, None)
    image_actor3 = image_actor.copy()
    image_actor3.opacity(slicer_opacity)
    image_actor3.display(image_actor3.shape[0] // 2, None, None)

    # connect the actors with the Render
    ren.add(stream_actor)
    ren.add(image_actor)
    ren.add(image_actor2)
    ren.add(image_actor3)

    show_m = window.ShowManager(ren, size=(1200, 900))
    show_m.initialize()

    # One callback + slider per axis; each callback moves a single slicer.
    def change_slice(obj, event):
        z = int(np.round(obj.get_value()))
        image_actor.display_extent(0, img.shape[0] - 1,
                                   0, img.shape[1] - 1, z, z)

    slicer = widget.slider(show_m.iren, show_m.ren,
                           callback=change_slice,
                           min_value=0,
                           max_value=img.shape[2] - 1,
                           value=img.shape[2] / 2,
                           label="Move slice",
                           right_normalized_pos=(.98, 0.6),
                           size=(120, 0),
                           label_format="%0.1f",
                           color=(1., 1., 1.),
                           selected_color=(0.86, 0.33, 1.))

    def change_slice2(obj, event):
        y = int(np.round(obj.get_value()))
        image_actor2.display_extent(0, img.shape[0] - 1, y, y,
                                    0, img.shape[2] - 1)

    slicer2 = widget.slider(show_m.iren, show_m.ren,
                            callback=change_slice2,
                            min_value=0,
                            max_value=img.shape[1] - 1,
                            value=img.shape[1] / 2,
                            label="Coronal slice",
                            right_normalized_pos=(.98, 0.3),
                            size=(120, 0),
                            label_format="%0.1f",
                            color=(1., 1., 1.),
                            selected_color=(0.86, 0.33, 1.))

    def change_slice3(obj, event):
        x = int(np.round(obj.get_value()))
        image_actor3.display_extent(x, x, 0, img.shape[1] - 1,
                                    0, img.shape[2] - 1)

    slicer3 = widget.slider(show_m.iren, show_m.ren,
                            callback=change_slice3,
                            min_value=0,
                            max_value=img.shape[0] - 1,
                            value=img.shape[0] / 2,
                            label="Sagittal slice",
                            right_normalized_pos=(.98, 0.9),
                            size=(120, 0),
                            label_format="%0.1f",
                            color=(1., 1., 1.),
                            selected_color=(0.86, 0.33, 1.))

    # Last known window size, kept in a one-element list owned by this call
    # (not a module-level global) so other viewers cannot clobber it.
    last_size = [ren.GetSize()]

    def win_callback(obj, event):
        # Re-place all sliders whenever the window has been resized.
        if last_size[0] != obj.GetSize():
            slicer.place(ren)
            slicer2.place(ren)
            slicer3.place(ren)
            last_size[0] = obj.GetSize()

    # NOTE(review): second initialize() kept from the original; presumably
    # required after the widgets were created - confirm before removing.
    show_m.initialize()

    # interact with the available 3D and 2D objects
    show_m.add_window_callback(win_callback)
    show_m.render()
    show_m.start()

    ren.zoom(1.5)
    ren.reset_clipping_range()

    window.record(ren, out_path=out_path, size=(1200, 900),
                  reset_camera=False)
    del show_m