Esempio n. 1
0
def visualize_pdf(pdf, gtab=None):
    """Render ODF glyphs derived from a propagator (PDF) volume.

    The last axis of ``pdf`` holds a flattened cubic grid of ``n**3``
    samples; it is unflattened to ``(n, n, n)`` before the ODFs are computed.

    Parameters
    ----------
    pdf : ndarray
        Array of shape ``(*spatial, n**3)``.
    gtab : GradientTable, optional
        When given, the grid is Fourier transformed (``np.fft.fftn`` over the
        three grid axes) and fitted with a ``DiffusionSpectrumModel`` built
        from this table. Otherwise ``get_odf`` is applied to each voxel grid.

    Raises
    ------
    ValueError
        If the last axis of ``pdf`` does not hold a perfect cube of samples.
    """
    shape = pdf.shape
    spatial = shape[:-1]
    n_samples = shape[-1]
    # Round (instead of ceil) so floating-point error in the cube root
    # cannot select the wrong edge length, and fail with a clear message
    # when the sample count is not a perfect cube.
    grid_s = int(round(n_samples ** (1 / 3)))
    if grid_s ** 3 != n_samples:
        raise ValueError(
            "Last axis of pdf must hold a cubic grid of samples, "
            f"got {n_samples}")
    signal = pdf.reshape(*spatial, grid_s, grid_s, grid_s)

    sphere = get_sphere('repulsion724')

    if gtab is not None:
        dsmodel = DiffusionSpectrumModel(gtab)
        signal_xq = np.fft.fftn(signal, axes=[-3, -2, -1])
        # Flatten the grid while keeping *all* spatial axes (previously
        # shape[:2] folded a third spatial axis into the sample axis).
        dsfit = dsmodel.fit(signal_xq.reshape(*spatial, -1))
        odfs = dsfit.odf(sphere)
    else:
        voxels = signal.reshape(-1, grid_s, grid_s, grid_s)
        odfs = np.stack([get_odf(vox, sphere) for vox in voxels])
        odfs = odfs.reshape(*spatial, -1)

    # A single 2D slice gets a singleton z axis added before slicing.
    if len(odfs.shape) == 3:
        odfs = odfs[:, :, np.newaxis]

    # visualize
    r = window.Scene()
    sfu = actor.odf_slicer(odfs, sphere=sphere, colormap='plasma', scale=0.5)
    sfu.display(z=0)
    r.add(sfu)
    window.show(r)
Esempio n. 2
0
def visualize(evals, evecs, viz_scale=0.5, fname='tensor_ellipsoids.png',
              size=(1000, 1000), interactive=True):
    """Render tensor ellipsoids colored by FA and save a snapshot.

    Parameters
    ----------
    evals, evecs : ndarray
        Eigenvalues / eigenvectors passed to ``fractional_anisotropy``,
        ``color_fa`` and ``actor.tensor_slicer``.
    viz_scale : float
        Glyph scale passed to ``actor.tensor_slicer``.
    fname : str
        Output image path for ``window.record``.
    size : tuple of int
        Size of the recorded image (previously ignored in favor of a
        hard-coded ``(1000, 1000)``).
    interactive : bool
        Open an interactive window before recording (default True, as
        before).
    """
    from dipy.data import get_sphere
    from dipy.reconst.dti import fractional_anisotropy, color_fa

    ren = window.Scene()
    sphere = get_sphere('symmetric642')

    # Per-voxel colors derived from FA, with NaNs zeroed and FA clipped to
    # [0, 1] before color_fa.
    fa = fractional_anisotropy(evals)
    fa[np.isnan(fa)] = 0
    fa = np.clip(fa, 0, 1)
    rgb = color_fa(fa, evecs)
    # Keep index 0 of the third axis as a length-1 axis.
    # Not normalized by cfa.max(): that would raise contrast but make it
    # differ between plots.
    cfa = rgb[:, :, 0:1]

    ren.add(actor.tensor_slicer(evals, evecs, sphere=sphere,
                                scalar_colors=cfa, scale=viz_scale,
                                norm=False))

    if interactive:
        window.show(ren)

    window.record(ren, n_frames=1, out_path=fname, size=size)
Esempio n. 3
0
def plot_directions(peak_dirs, peak_values, x_angle, y_angle, size=(300, 300)):
    """Show peak directions as arrows in an interactive fury window.

    To inspect a single slice, pass a sliced volume of ``peak_dirs`` and
    ``peak_values`` and adjust ``x_angle`` / ``y_angle``. See the
    Tractography Directional Field QA Tutorial for examples.

    Parameters
    ----------
    peak_dirs : np.ndarray
        Peak direction vectors (x, y, z).
    peak_values : np.ndarray
        Peak values / magnitudes.
    x_angle : int
        Angle handed to ``scene.roll``.
    y_angle : int
        Angle handed to ``scene.pitch``.
    size : tuple
        Size of the fury window.
    """
    arrow_args = generate_3_d_directions(peak_dirs, peak_values)
    centers, directions, arrow_colors, heights = arrow_args

    scene = window.Scene()
    scene.add(actor.arrow(centers, directions, arrow_colors, heights))

    # Orient the camera before opening the window.
    scene.roll(x_angle)
    scene.pitch(y_angle)

    window.show(scene, size=size)
Esempio n. 4
0
def show_weighted_tractography(folder_name,
                               vec_vols,
                               s_list,
                               bundle_short,
                               direction,
                               downsamp=1):
    """Render streamlines colored by a per-streamline value and save a PNG.

    Parameters
    ----------
    folder_name : str
        Root folder; the image is written to
        ``<folder_name>/streamlines/ax_fa_corr_<bundle_short>_<direction>.png``.
    vec_vols : sequence of float
        Value per streamline, mapped through the 'seismic' colormap.
        Not modified.
    s_list : sequence
        Streamlines passed to ``actor.line``.
    bundle_short, direction : str
        Used only to build the output file name.
    downsamp : int
        Keep every ``downsamp``-th streamline / value.
    """
    import os

    # os.path.join instead of a hard-coded backslash path, so this also
    # works outside Windows.
    s_img = os.path.join(folder_name, 'streamlines',
                         f'ax_fa_corr_{bundle_short}_{direction}.png')

    if downsamp != 1:
        vec_vols = vec_vols[::downsamp]
        s_list = s_list[::downsamp]
    # Work on a copy: appending to vec_vols directly mutated the caller's
    # list (and failed for ndarrays).
    values = list(vec_vols)
    # Two anchor values 1 and -1 are included when building the colormap,
    # presumably so its normalization always spans at least [-1, 1]; their
    # colors are dropped again right after.
    cmap = create_colormap(np.asarray(values + [1, -1]), name='seismic')
    cmap = cmap[:-2]
    print(min(values), max(values))
    w_actor = actor.line(s_list, cmap, linewidth=1.2)
    r = window.Scene()
    r.add(w_actor)
    window.show(r)
    # NOTE(review): the result of camera_info() is fed straight into
    # set_camera; confirm this actually transfers camera parameters.
    r.set_camera(r.camera_info())
    window.record(r, out_path=s_img, size=(800, 800))
Esempio n. 5
0
def calculate_fodf(gtab,
                   images,
                   name,
                   sphere=default_sphere,
                   radius=10,
                   fa_threshold=0.7):
    """Fit a constrained spherical deconvolution model and save a rendering.

    The response is estimated with ``auto_response`` (ROI radius ``radius``,
    FA threshold ``fa_threshold``). The fODFs evaluated on ``sphere`` are
    drawn with ``actor.odf_slicer`` and saved to
    ``results/csd_odfs_<name>.png``.

    Returns
    -------
    The fit object returned by ``ConstrainedSphericalDeconvModel.fit``.
    """
    response, _ratio = auto_response(gtab, images,
                                     roi_radius=radius,
                                     fa_thr=fa_threshold)

    fit = ConstrainedSphericalDeconvModel(gtab, response).fit(images)
    glyphs = actor.odf_slicer(fit.odf(sphere),
                              sphere=sphere,
                              scale=0.9,
                              norm=False,
                              colormap='plasma')

    scene = window.Scene()
    scene.add(glyphs)

    print('Saving illustration as csd_odfs_{}.png'.format(name))
    window.record(scene,
                  out_path='results/csd_odfs_{}.png'.format(name),
                  size=(600, 600))
    return fit
Esempio n. 6
0
def visualize_roi(roi,
                  affine_or_mapping=None,
                  static_img=None,
                  roi_affine=None,
                  static_affine=None,
                  reg_template=None,
                  scene=None,
                  color=np.array([1, 0, 0]),
                  opacity=1.0,
                  inline=False,
                  interact=False):
    """
    Render a region of interest into a VTK viz as a volume

    Parameters
    ----------
    roi : ndarray, str or image object
        ROI volume. A str is loaded with ``nib.load(roi).get_fdata()``; any
        other non-ndarray object must provide ``get_fdata()``.
    affine_or_mapping : ndarray, str, Nifti1Image or mapping, optional
        An ndarray is treated as an affine and the ROI is resampled with
        ``reg.resample`` (requires ``static_img``, ``roi_affine`` and
        ``static_affine``). A str or Nifti1Image is read with
        ``reg.read_mapping`` (requires ``reg_template`` and ``static_img``).
        The resulting (or any other) mapping object has its
        ``transform_inverse`` applied with nearest-neighbour interpolation.
    static_img, roi_affine, static_affine, reg_template : optional
        Inputs for the transform selected by ``affine_or_mapping``.
    scene : fury Scene, optional
        Scene to add the ROI contour to; a new one is created when None.
    color : ndarray
        RGB color of the contour.
    opacity : float
        Opacity of the contour.
    inline : bool
        If True, snapshot the scene to ``fig.png`` in the temp directory and
        display it as a PNG.
    interact : bool
        Forwarded to ``_inline_interact``.

    Returns
    -------
    Whatever ``_inline_interact(scene, inline, interact)`` returns.

    Raises
    ------
    ValueError
        If an input required by the selected transform is missing.
    """
    if not isinstance(roi, np.ndarray):
        if isinstance(roi, str):
            roi = nib.load(roi).get_fdata()
        else:
            roi = roi.get_fdata()

    if affine_or_mapping is not None:
        if isinstance(affine_or_mapping, np.ndarray):
            # This is an affine:
            if (static_img is None or roi_affine is None
                    or static_affine is None):
                # One concatenated message (the pieces were previously passed
                # as separate ValueError arguments).
                raise ValueError(
                    "If using an affine to transform an ROI, "
                    "need to also specify all of the following "
                    "inputs: `static_img`, `roi_affine`, `static_affine`")
            roi = reg.resample(roi, static_img, roi_affine, static_affine)
        else:
            # Assume it is  a mapping:
            if (isinstance(affine_or_mapping, str)
                    or isinstance(affine_or_mapping, nib.Nifti1Image)):
                if reg_template is None or static_img is None:
                    raise ValueError(
                        "If using a mapping to transform an ROI, need to "
                        "also specify all of the following inputs: "
                        "`reg_template`, `static_img`")
                affine_or_mapping = reg.read_mapping(affine_or_mapping,
                                                     static_img, reg_template)

            roi = auv.patch_up_roi(
                affine_or_mapping.transform_inverse(
                    roi, interpolation='nearest')).astype(bool)

    if scene is None:
        scene = window.Scene()

    roi_actor = actor.contour_from_roi(roi, color=color, opacity=opacity)
    scene.add(roi_actor)

    if inline:
        tdir = tempfile.gettempdir()
        fname = op.join(tdir, "fig.png")
        window.snapshot(scene, fname=fname)
        display.display_png(display.Image(fname))

    return _inline_interact(scene, inline, interact)
Esempio n. 7
0
    def build_scene(self):
        """Return a new Scene populated with this object's cluster actors.

        Side effect: ``self.mem`` is replaced with a fresh ``GlobalHorizon``.
        ``add_cluster_actors`` is called with ``enable_callbacks=False``.
        """
        self.mem = GlobalHorizon()
        new_scene = window.Scene()
        self.add_cluster_actors(
            new_scene,
            self.tractograms,
            self.cluster_thr,
            enable_callbacks=False,
        )
        return new_scene
Esempio n. 8
0
    def test_scene(self):
        """A new Scene holding a sphere actor reports a size of (0, 0)."""
        n_points = 100
        centers = 10 * np.random.rand(n_points, 3)
        rgba = np.random.rand(n_points, 4)
        radii = np.random.rand(n_points) + 0.5

        spheres = actor.sphere(centers=centers, colors=rgba, radii=radii)

        scene = window.Scene()
        scene.add(spheres)

        self.assertEqual((0, 0), scene.GetSize())
def plotTract(tractIn):
    """Display streamlines in an interactive 600x600 fury window.

    Parameters
    ----------
    tractIn : sequence of streamlines
        Passed directly to ``actor.line``.
    """
    from dipy.viz import window, actor

    renderer = window.Scene()
    stream_actor = actor.line(tractIn)

    # Enable inline matplotlib only when running under IPython: outside of
    # IPython ``get_ipython`` is undefined and the unguarded call raised
    # NameError.
    try:
        shell = get_ipython()  # noqa: F821 - injected by IPython
    except NameError:
        shell = None
    if shell is not None:
        shell.run_line_magic('matplotlib', 'inline')

    renderer.add(stream_actor)
    window.show(renderer, size=(600, 600), reset_camera=True)
Esempio n. 10
0
def plot_streamlines(streamlines):
    """Show streamlines, colored by ``colormap.line_colors``, in a fury window.

    Does nothing when ``has_fury`` is false.
    """
    if not has_fury:
        return

    # Compute the line colors once (they were computed twice, the first
    # result being discarded).
    line_colors = colormap.line_colors(streamlines)
    streamlines_actor = actor.line(streamlines, line_colors)

    # Create the 3D display.
    scene = window.Scene()
    scene.add(streamlines_actor)

    # Interactive display; use window.record(...) to save a still image.
    window.show(scene)
def show_template_bundles(bundles, static, show=True, fname=None):
    """Overlay bundles as orange streamtubes on a slicer of a template volume.

    Parameters
    ----------
    bundles : streamlines
        Passed to ``actor.streamtube``.
    static : ndarray
        Volume passed to ``actor.slicer``.
    show : bool
        Open an interactive window when True.
    fname : str, optional
        When given, save one 900x900 frame to this path (after the
        interactive window, if shown).
    """
    scene = window.Scene()
    scene.add(actor.slicer(static))
    scene.add(actor.streamtube(bundles, window.colors.orange, linewidth=0.3))

    if show:
        window.show(scene)
    if fname is None:
        return
    window.record(scene, n_frames=1, out_path=fname, size=(900, 900))
Esempio n. 12
0
def show_both_bundles(bundles, colors=None, show=True, fname=None):
    """Render several bundles as streamtubes on a white background.

    Parameters
    ----------
    bundles : sequence
        Bundles to render; each is passed to ``actor.streamtube``.
    colors : sequence, optional
        One color per bundle. When None (the default), ``None`` is passed as
        the color for every bundle instead of raising TypeError on
        ``colors[i]``.
    show : bool
        Open an interactive window when True.
    fname : str, optional
        When given, save a 900x900 frame to this path. This now also happens
        when ``show`` is True, matching ``show_template_bundles``.
    """
    scene = window.Scene()
    scene.SetBackground(1., 1, 1)
    for i, bundle in enumerate(bundles):
        color = None if colors is None else colors[i]
        streamtube_actor = actor.streamtube(bundle, color, linewidth=0.3)
        # Same fixed reorientation for every bundle.
        streamtube_actor.RotateX(-90)
        streamtube_actor.RotateZ(90)
        scene.add(streamtube_actor)
    if show:
        window.show(scene)
    if fname is not None:
        window.record(scene, out_path=fname, size=(900, 900))
Esempio n. 13
0
def viewclusters(clusters, streamlines, outpath=None, interactive=False):
    """Render streamlines colored by cluster membership.

    Streamlines not listed in any cluster's ``indices`` stay white.

    Parameters
    ----------
    clusters : iterable of clusters
        Must expose ``centroids``; each item must expose ``indices``.
    streamlines : sequence
        All streamlines, indexed by the clusters' ``indices``.
    outpath : str, optional
        If given, save a 600x600 image to this path. Nothing is saved when
        None (previously ``record`` was called regardless).
    interactive : bool
        If True, open an interactive window.
    """
    # Renamed from ``colormap`` to avoid shadowing the module of that name.
    # NOTE(review): np.ravel flattens every centroid coordinate, so
    # create_colormap receives one value per coordinate rather than one per
    # cluster; zip then pairs clusters with the first colors only.
    cluster_colors = actor.create_colormap(np.ravel(clusters.centroids))
    colormap_full = np.ones((len(streamlines), 3))
    for cluster, color in zip(clusters, cluster_colors):
        colormap_full[cluster.indices] = color

    scene = window.Scene()
    scene.SetBackground(1, 1, 1)
    scene.add(actor.streamtube(streamlines, colormap_full))
    if outpath is not None:
        window.record(scene, out_path=outpath, size=(600, 600))

    # Enables/disables interactive visualization
    if interactive:
        window.show(scene)
Esempio n. 14
0
def visualize_streamline(darray,
                         score,
                         save_able=False,
                         save_name='default.png',
                         control_par=1,
                         hue=[0.5, 1]):
    """Render streamlines colored by a score, with a scalar bar.

    Parameters
    ----------
    darray : array-like
        Tracks; each entry along the first axis is passed through
        ``zero_remove`` before being appended to a ``Streamlines`` container.
    score : sequence of float
        Scalar values passed to ``actor.line`` for coloring (presumably one
        per streamline).
    save_able : bool
        If True, save one 800x800 frame to ``save_name``; otherwise open an
        interactive window.
    save_name : str
        Output path used when ``save_able`` is True.
    control_par : float
        The color scale runs from ``min(score)`` to
        ``max(score) / control_par``.
    hue : sequence of two floats
        Hue range of the lookup table.
    """
    streamlines_evl = Streamlines()
    n_tracks = np.shape(darray)[0]
    for idx in range(n_tracks):
        streamlines_evl.append(zero_remove(darray[idx]))

    ren = window.Scene()

    lut_cmap = actor.colormap_lookup_table(
        scale_range=(min(score), max(score) / control_par),
        hue_range=hue,
        saturation_range=[0.0, 1.0])
    ren.add(actor.scalar_bar(lut_cmap))

    ren.add(actor.line(streamlines_evl,
                       score,
                       linewidth=0.1,
                       lookup_colormap=lut_cmap))

    if save_able:
        window.record(ren, n_frames=1, out_path=save_name, size=(800, 800))
    else:
        window.show(ren)
Esempio n. 15
0
def show_tracts(hue, saturation, scale, streamlines, mean_vol_per_tract,
                folder_name, fig_type):
    """Show streamtubes colored by ``mean_vol_per_tract``, then save a PNG.

    The image is written to
    ``<folder_name>/streamlines/mean_pasi_weighted<fig_type>.png`` (using
    ``os.sep``) after the interactive window is closed.
    """
    from dipy.viz import window, actor

    lookup = actor.colormap_lookup_table(hue_range=hue,
                                         saturation_range=saturation,
                                         scale_range=scale)
    tubes = actor.streamtube(streamlines,
                             mean_vol_per_tract,
                             linewidth=0.5,
                             lookup_colormap=lookup)
    color_bar = actor.scalar_bar(lookup)

    scene = window.Scene()
    for item in (tubes, color_bar):
        scene.add(item)

    out_file = (f'{folder_name}{os.sep}streamlines{os.sep}'
                f'mean_pasi_weighted{fig_type}.png')
    window.show(scene)
    scene.set_camera(scene.camera_info())
    window.record(scene, out_path=out_file, size=(800, 800))
Esempio n. 16
0
def calculate_odf(gtab, data, sh_order=4):
    """Fit a CSA ODF model on a fixed sub-volume and save a rendering.

    The sub-volume is ``data[30:65, 40:75, 39:40]``. The image goes to
    ``results/csa_odfs_<N>.png``, where ``N = data.shape[-1] - 1``.

    Returns
    -------
    ndarray
        CSA ODF values on ``default_sphere``, clipped below at 0.
    """
    model = CsaOdfModel(gtab, sh_order)

    region = data[30:65, 40:75, 39:40]
    odf = model.fit(region).odf(default_sphere)
    # Lower clip at 0; the upper bound is each voxel's own maximum.
    per_voxel_max = np.max(odf, -1)[..., None]
    odf = np.clip(odf, 0, per_voxel_max)

    glyphs = actor.odf_slicer(odf,
                              sphere=default_sphere,
                              scale=0.9,
                              norm=False,
                              colormap='plasma')

    scene = window.Scene()
    scene.add(glyphs)

    tag = data.shape[-1] - 1
    print('Saving illustration as csa_odfs_{}.png'.format(tag))
    window.record(scene,
                  out_path='results/csa_odfs_{}.png'.format(tag),
                  size=(600, 600))
    return odf
Esempio n. 17
0
def qball(gtab, data, name, sh_order=4):
    """Fit a Q-ball model on slice ``data[:, :, 39:40]`` and save a rendering.

    The image is written to ``results/qball_odfs_<name>.png``.

    Returns
    -------
    tuple
        ``(odf, shm_coeff)``: the ODF on ``default_sphere`` and the fit's
        spherical-harmonic coefficients.
    """
    fit = QballModel(gtab, sh_order).fit(data[:, :, 39:40])
    odf = fit.odf(default_sphere)
    glyphs = actor.odf_slicer(odf,
                              sphere=default_sphere,
                              scale=0.9,
                              norm=False,
                              colormap='plasma')

    scene = window.Scene()
    scene.add(glyphs)

    print('Saving illustration as qball_odfs_{}.png'.format(name))
    window.record(scene,
                  out_path='results/qball_odfs_{}.png'.format(name),
                  size=(600, 600))
    return odf, fit.shm_coeff
Esempio n. 18
0
    def visualize(self,
                  out_path='out/',
                  outer_box=True,
                  axes=True,
                  clip_neg=False,
                  azimuth=0,
                  elevation=0,
                  n_frames=1,
                  mag=1,
                  video=False,
                  viz_type='ODF',
                  mask=None,
                  mask_roi=None,
                  skip_n=1,
                  skip_n_roi=1,
                  scale=1,
                  roi_scale=1,
                  zoom_start=1.0,
                  zoom_end=1.0,
                  top_zoom=1,
                  interact=False,
                  save_parallels=False,
                  my_cam=None,
                  compress=True,
                  roi=None,
                  corner_text='',
                  scalemap=None,
                  titles_on=True,
                  scalebar_on=True,
                  invert=False,
                  flat=False,
                  colormap='bwr',
                  global_cm=True,
                  camtilt=False,
                  axes_on=False,
                  colors=None,
                  arrows=None,
                  arrow_color=np.array([0, 0, 0]),
                  linewidth=0.1,
                  mark_slices=None,
                  z_shift=0,
                  profiles=[],
                  markers=[],
                  marker_colors=[],
                  marker_scale=1,
                  normalize_glyphs=True,
                  gamma=1,
                  density_max=1):
        """Render this field with VTK and write TIFF frames to disk.

        One renderer (column) is created per entry of ``viz_type``. A second
        row showing the ``roi`` sub-volume is added when ``roi`` is given.

        Parameters (selected)
        ---------------------
        out_path : str
            Output prefix. ``util.mkdir(out_path)`` is called first.
            Multi-frame renders write ``out_path + '000.tif'``, ...; a single
            frame writes ``out_path + '.tif'``. With ``save_parallels`` the
            files are ``out_path + ('yz' | 'xy' | 'xz') + '.tif'``.
        viz_type : str or list of str
            Any of 'ODF', 'ODF Sphere', 'Ellipsoid', 'Peak', 'Principal',
            'Density'.
        roi : array-like, optional
            ``[[x0, y0, z0], [x1, y1, z1]]`` bounds of the sub-volume shown in
            the second row.
        n_frames : int
            Number of frames; the azimuth advances between frames.
        video : bool
            If True, assemble the frames into ``out_path[:-1] + '.avi'`` with
            ``ffmpeg`` (must be on PATH).
        interact : bool
            If False, render off-screen. If True, open an interactive window
            at the end.
        zoom_start, zoom_end, clip_neg : currently unused.
            Zoom levels are fixed: 1.7 with ``save_parallels``, otherwise 1.3
            (times ``top_zoom`` on the first row).

        Returns
        -------
        The ``my_cam`` argument, unchanged.
        """
        log.info('Preparing to render ' + out_path)

        # Handle scalemap
        if scalemap is None:
            scalemap = util.ScaleMap(min=np.min(self.f[..., 0]),
                                     max=np.max(self.f[..., 0]))

        # Prepare output
        util.mkdir(out_path)

        # Setup vtk renderers
        renWin = vtk.vtkRenderWindow()

        if not interact:
            renWin.SetOffScreenRendering(1)
        if isinstance(viz_type, str):
            viz_type = [viz_type]

        # Rows and columns
        cols = len(viz_type)
        if roi is None:
            rows = 1
        else:
            rows = 2

        # Builtin int/float/bool replace np.int/np.float/np.bool, which were
        # removed in NumPy 1.24 (they were plain aliases of the builtins).
        renWin.SetSize(int(500 * mag * cols), int(500 * mag * rows))

        # Select background color
        if save_parallels:
            bg_color = [1, 1, 1]
            line_color = np.array([0, 0, 0])
            line_bcolor = np.array([1, 1, 1])
        else:
            if not invert:
                bg_color = [0, 0, 0]
                line_color = np.array([1, 1, 1])
                line_bcolor = np.array([0, 0, 0])
            else:
                bg_color = [1, 1, 1]
                line_color = np.array([0, 0, 0])
                line_bcolor = np.array([1, 1, 1])

        # Per-renderer zoom levels. These lists no longer reuse the names of
        # the zoom_start/zoom_end arguments.
        rens = []
        zoom_starts = []
        zoom_ends = []
        # Fixed colors for profile lines. They are kept separate from the
        # ``colors`` argument, which used to be overwritten here and then
        # leaked into later "Peak" renderers.
        profile_colors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        for row in range(rows):
            for col in range(cols):
                # Render
                ren = window.Scene()
                rens.append(ren)
                # Compare by value; `is` on strings depends on interning.
                if viz_type[col] == 'Density':
                    ren.background([0, 0, 0])
                    line_color = np.array([1, 1, 1])
                else:
                    ren.background(bg_color)
                ren.SetViewport(col / cols, (rows - row - 1) / rows,
                                (col + 1) / cols, (rows - row) / rows)
                renWin.AddRenderer(ren)
                iren = vtk.vtkRenderWindowInteractor()
                iren.SetRenderWindow(renWin)

                # Mask
                if mask is None:
                    mask = np.ones((self.X, self.Y, self.Z), dtype=bool)
                if mask_roi is None:
                    mask_roi = mask

                # Main vs roi
                if row == 0:
                    data = self.f
                    skip_mask = np.zeros(mask.shape, dtype=bool)
                    skip_mask[::skip_n, ::skip_n, ::skip_n] = 1
                    my_mask = np.logical_and(mask, skip_mask)
                    # Guarantee at least one voxel is rendered.
                    if np.sum(my_mask) == 0:
                        my_mask[0, 0, 0] = True
                else:
                    data = self.f[roi[0][0]:roi[1][0], roi[0][1]:roi[1][1],
                                  roi[0][2]:roi[1][2], :]
                    roi_mask = mask_roi[roi[0][0]:roi[1][0],
                                        roi[0][1]:roi[1][1],
                                        roi[0][2]:roi[1][2]]
                    skip_mask = np.zeros(roi_mask.shape, dtype=bool)
                    skip_mask[::skip_n_roi, ::skip_n_roi, ::skip_n_roi] = 1
                    my_mask = np.logical_and(roi_mask, skip_mask)
                    scale = roi_scale

                # Add visuals to renderer
                if viz_type[col] == "ODF":
                    renWin.SetMultiSamples(4)
                    log.info('Rendering ' + str(np.sum(my_mask)) + ' ODFs')
                    fodf_spheres = viz.odf_sparse(data,
                                                  self.Binv,
                                                  sphere=self.sphere,
                                                  scale=skip_n * scale * 0.5,
                                                  norm=False,
                                                  colormap=colormap,
                                                  mask=my_mask,
                                                  global_cm=global_cm,
                                                  scalemap=scalemap,
                                                  odf_sphere=False,
                                                  flat=flat,
                                                  normalize=normalize_glyphs)

                    ren.add(fodf_spheres)
                elif viz_type[col] == "ODF Sphere":
                    renWin.SetMultiSamples(4)
                    log.info('Rendering ' + str(np.sum(my_mask)) + ' ODFs')
                    fodf_spheres = viz.odf_sparse(data,
                                                  self.Binv,
                                                  sphere=self.sphere,
                                                  scale=skip_n * scale * 0.5,
                                                  norm=False,
                                                  colormap=colormap,
                                                  mask=my_mask,
                                                  global_cm=global_cm,
                                                  scalemap=scalemap,
                                                  odf_sphere=True,
                                                  flat=flat)
                    ren.add(fodf_spheres)
                elif viz_type[col] == "Ellipsoid":
                    renWin.SetMultiSamples(4)
                    log.info(
                        'Warning: scaling is not implemented for ellipsoids')
                    log.info('Rendering ' + str(np.sum(my_mask)) +
                             ' ellipsoids')
                    fodf_peaks = viz.tensor_slicer_sparse(data,
                                                          sphere=self.sphere,
                                                          scale=skip_n *
                                                          scale * 0.5,
                                                          mask=my_mask)
                    ren.add(fodf_peaks)
                elif viz_type[col] == "Peak":
                    renWin.SetMultiSamples(4)
                    log.info('Rendering ' + str(np.sum(my_mask)) + ' peaks')
                    fodf_peaks = viz.peak_slicer_sparse(
                        data,
                        self.Binv,
                        self.sphere.vertices,
                        linewidth=linewidth,
                        scale=skip_n * scale * 0.5,
                        colors=colors,
                        mask=my_mask,
                        scalemap=scalemap,
                        normalize=normalize_glyphs)
                    fodf_peaks.GetProperty().LightingOn()
                    fodf_peaks.GetProperty().SetDiffuse(
                        0.4)  # Doesn't work (VTK bug I think)
                    fodf_peaks.GetProperty().SetAmbient(0.15)
                    fodf_peaks.GetProperty().SetSpecular(0)
                    fodf_peaks.GetProperty().SetSpecularPower(0)

                    ren.add(fodf_peaks)
                elif viz_type[col] == "Principal":
                    log.info(
                        'Warning: scaling is not implemented for principals')
                    log.info('Rendering ' + str(np.sum(my_mask)) +
                             ' principals')
                    fodf_peaks = viz.principal_slicer_sparse(
                        data,
                        self.Binv,
                        self.sphere.vertices,
                        scale=skip_n * scale * 0.5,
                        mask=my_mask)
                    ren.add(fodf_peaks)
                elif viz_type[col] == "Density":
                    renWin.SetMultiSamples(0)  # Must be zero for smooth
                    # renWin.SetAAFrames(4) # Slow antialiasing for volume renders
                    log.info('Rendering density')
                    gamma_corr = np.where(data[..., 0] > 0,
                                          data[..., 0]**gamma, data[..., 0])
                    # NOTE(review): this mutates scalemap (possibly the
                    # caller's object) and compounds if several Density
                    # renderers are created.
                    scalemap.max = density_max * scalemap.max**gamma
                    volume = viz.density_slicer(gamma_corr, scalemap)
                    ren.add(volume)

                X = float(data.shape[0])
                Y = float(data.shape[1])
                Z = float(data.shape[2]) - z_shift

                # Titles
                if row == 0 and titles_on:
                    viz.add_text(ren, viz_type[col], 0.5, 0.96, mag)

                # Scale bar
                if col == cols - 1 and not save_parallels and scalebar_on:
                    yscale = 1e-3 * self.vox_dim[1] * data.shape[1]
                    yscale_label = '{:.2g}'.format(yscale) + ' um'
                    viz.add_text(ren, yscale_label, 0.5, 0.03, mag)
                    viz.draw_scale_bar(ren, X, Y, Z, [1, 1, 1])

                # Corner text
                if row == rows - 1 and col == 0 and titles_on:
                    viz.add_text(ren, corner_text, 0.03, 0.03, mag, ha='left')

                # Draw boxes
                if outer_box:
                    if row == 0:
                        viz.draw_outer_box(
                            ren,
                            np.array([[0, 0, 0], [X, Y, Z]]) - 0.5, line_color)
                    if row == 1:
                        viz.draw_outer_box(
                            ren,
                            np.array([[0, 0, 0], [X, Y, Z]]) - 0.5, [0, 1, 1])

                # Add colored axes
                if axes:
                    viz.draw_axes(ren, np.array([[0, 0, 0], [X, Y, Z]]) - 0.5)

                # Add custom arrows
                if arrows is not None:
                    for i in range(arrows.shape[0]):
                        viz.draw_single_arrow(ren,
                                              arrows[i, 0, :],
                                              arrows[i, 1, :],
                                              color=arrow_color)
                        viz.draw_unlit_line(ren, [
                            np.array([arrows[i, 0, :], [X / 2, Y / 2, Z / 2]])
                        ], [arrow_color],
                                            lw=0.3,
                                            scale=1.0)

                # Draw roi box
                if row == 0 and roi is not None:
                    maxROI = np.max([
                        roi[1][0] - roi[0][0], roi[1][1] - roi[0][1],
                        roi[1][2] - roi[0][2]
                    ])
                    maxXYZ = np.max([self.X, self.Y, self.Z])
                    viz.draw_outer_box(ren,
                                       roi, [0, 1, 1],
                                       lw=0.3 * maxXYZ / maxROI)
                    viz.draw_axes(ren, roi, lw=0.3 * maxXYZ / maxROI)

                # Draw marked slices: four points 90 degrees apart on a circle
                # of radius rr in the plane y = frac * Y; the first two are
                # repeated at the end (six points in total).
                if mark_slices is not None:
                    angles = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 0, np.pi / 2]
                    for slicen in mark_slices:
                        md = np.max((X, Z))
                        frac = slicen / data.shape[1]
                        rr = 0.83 * md
                        points = [
                            np.array([[
                                X / 2 + rr * np.cos(t), frac * Y,
                                Z / 2 + rr * np.sin(t)
                            ] for t in angles])
                        ]
                        viz.draw_unlit_line(ren,
                                            points,
                                            6 * [line_color + 0.6],
                                            lw=0.3,
                                            scale=1.0)

                # Draw markers
                for i, marker in enumerate(markers):
                    # Draw sphere
                    source = vtk.vtkSphereSource()
                    source.SetCenter(marker)
                    source.SetRadius(marker_scale)
                    source.SetThetaResolution(30)
                    source.SetPhiResolution(30)

                    # mapper
                    mapper = vtk.vtkPolyDataMapper()
                    mapper.SetInputConnection(source.GetOutputPort())

                    # actor (named so it does not shadow a module name)
                    marker_actor = vtk.vtkActor()
                    marker_actor.SetMapper(mapper)
                    marker_actor.GetProperty().SetColor(marker_colors[i, :])
                    marker_actor.GetProperty().SetLighting(0)
                    ren.AddActor(marker_actor)

                # Draw profile lines (the leftover pdb breakpoint is removed)
                for i, profile in enumerate(profiles):
                    n_seg = profile.shape[0]
                    viz.draw_unlit_line(ren, [profile],
                                        n_seg * [profile_colors[i, :]],
                                        lw=0.5,
                                        scale=1.0)

                    # Draw sphere at the start of the profile
                    source = vtk.vtkSphereSource()
                    source.SetCenter(profile[0])
                    source.SetRadius(1)
                    source.SetThetaResolution(30)
                    source.SetPhiResolution(30)

                    # mapper
                    mapper = vtk.vtkPolyDataMapper()
                    mapper.SetInputConnection(source.GetOutputPort())

                    # actor
                    profile_actor = vtk.vtkActor()
                    profile_actor.SetMapper(mapper)
                    profile_actor.GetProperty().SetLighting(0)

                    # assign actor to the renderer
                    ren.AddActor(profile_actor)

                # Setup cameras
                Rmax = np.linalg.norm([Z / 2, X / 2, Y / 2])
                Rcam_rad = Rmax / np.tan(np.pi / 12)
                Ntmax = np.max([X, Y])
                ZZ = Z
                if ZZ > Ntmax:
                    Rcam_edge = np.max([X / 2, Y / 2])
                else:
                    Rcam_edge = np.min([X / 2, Y / 2])
                Rcam = Rcam_edge + Rcam_rad
                if my_cam is None:
                    cam = ren.GetActiveCamera()
                    if camtilt:
                        cam.SetPosition(
                            ((X - 1) / 2, (Y - 1) / 2, (Z - 1) / 2 + Rcam))
                        cam.SetViewUp((-1, 0, 1))
                        if axes_on:
                            max_dim = np.max((X, Z))
                            viz.draw_unlit_line(ren, [
                                np.array([[(X - max_dim) / 2, Y / 2, Z / 2],
                                          [X / 2, Y / 2, +Z / 2],
                                          [X / 2, Y / 2, (Z + max_dim) / 2]])
                            ],
                                                3 * [line_color],
                                                lw=max_dim / 250,
                                                scale=1.0)
                    else:
                        cam.SetPosition(
                            ((X - 1) / 2 + Rcam, (Y - 1) / 2, (Z - 1) / 2))
                        cam.SetViewUp((0, 0, 1))
                    cam.SetFocalPoint(((X - 1) / 2, (Y - 1) / 2, (Z - 1) / 2))
                else:
                    ren.set_camera(*my_cam)
                ren.azimuth(azimuth)
                ren.elevation(elevation)

                # Set zooming
                if save_parallels:
                    zoom_starts.append(1.7)
                    zoom_ends.append(1.7)
                else:
                    if row == 0:
                        zoom_starts.append(1.3 * top_zoom)
                        zoom_ends.append(1.3 * top_zoom)
                    else:
                        zoom_starts.append(1.3)
                        zoom_ends.append(1.3)

        # Setup writer
        writer = vtk.vtkTIFFWriter()
        if not compress:
            writer.SetCompressionToNoCompression()

        # Execute renders
        az = 90
        naz = np.ceil(360 / n_frames)
        log.info('Rendering ' + out_path)
        if save_parallels:
            # Parallel rendering for summaries
            filenames = ['yz', 'xy', 'xz']
            zooms = [zoom_starts[0], 1.0, 1.0]
            azs = [90, -90, 0]
            els = [0, 0, 90]
            ren.projection(proj_type='parallel')
            ren.reset_camera()
            for i in tqdm(range(3)):
                ren.zoom(zooms[i])
                ren.azimuth(azs[i])
                ren.elevation(els[i])
                ren.reset_clipping_range()
                renderLarge = vtk.vtkRenderLargeImage()
                renderLarge.SetMagnification(1)
                renderLarge.SetInput(ren)
                renderLarge.Update()
                writer.SetInputConnection(renderLarge.GetOutputPort())
                writer.SetFileName(out_path + filenames[i] + '.tif')
                writer.Write()
        else:
            # Rendering for movies
            for j, ren in enumerate(rens):
                ren.zoom(zoom_starts[j])
            for i in tqdm(range(n_frames)):
                for j, ren in enumerate(rens):
                    ren.zoom(1 + ((zoom_ends[j] - zoom_starts[j]) / n_frames))
                    ren.azimuth(az)
                    ren.reset_clipping_range()

                renderLarge = vtk.vtkRenderLargeImage()
                renderLarge.SetMagnification(1)
                renderLarge.SetInput(ren)
                renderLarge.Update()
                writer.SetInputConnection(renderLarge.GetOutputPort())
                if n_frames != 1:
                    writer.SetFileName(out_path + str(i).zfill(3) + '.tif')
                else:
                    writer.SetFileName(out_path + '.tif')
                writer.Write()
                # First frame turns by 90 degrees; later frames by naz.
                az = naz

        # Interactive
        if interact:
            window.show(ren)

        # Generate video (requires ffmpeg)
        if video:
            log.info('Generating video from frames')
            fps = np.ceil(n_frames / 12)
            # Frames are written as .tif above, so read them as .tif (the
            # pattern previously pointed at nonexistent .png files).
            subprocess.call([
                'ffmpeg', '-nostdin', '-y', '-framerate',
                str(fps), '-loglevel', 'panic', '-i',
                out_path + '%03d' + '.tif', '-pix_fmt', 'yuvj420p', '-vcodec',
                'mjpeg', out_path[:-1] + '.avi'
            ])

        return my_cam
Esempio n. 19
0
def visualize_roi(roi,
                  affine_or_mapping=None,
                  static_img=None,
                  roi_affine=None,
                  static_affine=None,
                  reg_template=None,
                  name='ROI',
                  figure=None,
                  color=np.array([1, 0, 0]),
                  flip_axes=None,
                  opacity=1.0,
                  inline=False,
                  interact=False):
    """
    Draw a region of interest as a contour surface in a fury Scene.

    The ROI is passed, together with the transform/template arguments, to
    ``vut.prepare_roi``; the result is added to the scene as a contour actor.

    Parameters
    ----------
    roi : str or Nifti1Image
        The ROI information.

    affine_or_mapping : ndarray, Nifti1Image, or str, optional
        An affine transformation or mapping to apply to the ROI before
        visualization. Default: no transform.

    static_img : str or Nifti1Image, optional
        Template to resample the ROI to.
        Default: None

    roi_affine : ndarray, optional
        Forwarded to ``vut.prepare_roi``.
        Default: None

    static_affine : ndarray, optional
        Forwarded to ``vut.prepare_roi``.
        Default: None

    reg_template : str or Nifti1Image, optional
        Template to use for registration.
        Default: None

    name : str, optional
        Name of the ROI for the legend; not used by this VTK backend.
        Default: 'ROI'

    figure : fury Scene object, optional
        Scene the ROI is added to. Default: a new Scene is created.

    color : ndarray, optional
        RGB color of the ROI.
        Default: np.array([1, 0, 0])

    flip_axes : None
        Only present to conform the fury and plotly APIs; unused.

    opacity : float, optional
        Opacity of the ROI.
        Default: 1.0

    inline : bool
        Whether to embed the visualization inline in a notebook. Only works
        in the notebook context. Default: False.

    interact : bool
        Whether to provide an interactive VTK window for interaction.
        Default: False

    Returns
    -------
    Fury Scene object
    """
    prepared_roi = vut.prepare_roi(roi, affine_or_mapping, static_img,
                                   roi_affine, static_affine, reg_template)

    scene = window.Scene() if figure is None else figure
    contour = actor.contour_from_roi(prepared_roi, color=color,
                                     opacity=opacity)
    scene.add(contour)

    return _inline_interact(scene, inline, interact)
Esempio n. 20
0
def viewstreamlines_anat(streamlines_full,
                         anat_path,
                         affine,
                         ratio=1,
                         threshold=10.,
                         verbose=False):
    """
    Render streamlines over three orthogonal anatomical slices.

    The (optionally subsampled) streamlines are drawn as white tubes on top
    of the slicer's axial slice plus the middle sagittal and coronal slices
    of the anatomy. The scene is saved to ``bundles_and_3_slices.png`` in
    the current working directory.

    Parameters
    ----------
    streamlines_full : sequence of streamlines
        Streamlines to render.
    anat_path : str or array-like
        Path to an existing NIfTI file, or the anatomical data itself. For a
        4D volume only the first volume is shown.
    affine : ndarray or None
        Affine passed to the slicer. When None and ``anat_path`` is a file
        path, the affine stored in that file is used.
    ratio : int, optional
        Keep every ``ratio``-th streamline. Default: 1 (keep all).
    threshold : float, optional
        QuickBundles distance threshold for the cluster report.
    verbose : bool, optional
        If True, cluster the streamlines and print cluster statistics.

    Raises
    ------
    FileNotFoundError
        If ``anat_path`` is a string that does not name an existing file.
    """
    slicer_opacity = 0.6

    # Keep every `ratio`-th streamline to lighten rendering.
    if ratio != 1:
        streamline_cut = [
            sl for idx, sl in enumerate(streamlines_full) if idx % ratio == 0
        ]
    else:
        streamline_cut = streamlines_full

    if verbose:
        # Clustering only feeds this report, so skip it otherwise.
        clusters = QuickBundles(threshold=threshold).cluster(streamline_cut)
        print("Nb. clusters:", len(clusters))
        print("Cluster sizes:", [len(cluster) for cluster in clusters])
        print("Small clusters:", clusters < 10)
        if len(clusters) > 0:
            print("Streamlines indices of the first cluster:\n",
                  clusters[0].indices)
            print("Centroid of the last cluster:\n", clusters[-1].centroid)

    # The scene keeps its default background on purpose: the tubes are
    # white, so a white background would hide them.
    scene = window.Scene()
    scene.add(actor.streamtube(streamline_cut, window.colors.white))

    if isinstance(anat_path, str):
        if not os.path.exists(anat_path):
            raise FileNotFoundError(anat_path)
        anat_nifti = load_nifti(anat_path)
        # Support both an image-like object and a (data, affine, ...) tuple.
        try:
            data = anat_nifti.data
        except AttributeError:
            data = anat_nifti[0]
        if affine is None:
            try:
                affine = anat_nifti.affine
            except AttributeError:
                affine = anat_nifti[1]
    else:
        data = anat_path

    shape = np.shape(data)
    if len(shape) == 4:
        data = data[:, :, :, 0]
    image_actor_z = actor.slicer(data, affine)
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    scene.add(image_actor_z)
    scene.add(image_actor_x)
    scene.add(image_actor_y)
    global size
    size = scene.GetSize()
    show_m = window.ShowManager(scene, size=(1200, 900))
    show_m.initialize()

    interactive = False  # set to True to explore the scene in a window
    if interactive:
        # NOTE(review): `win_callback` is not defined in this function; it is
        # expected at module level.
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()
    else:
        window.record(scene,
                      out_path='bundles_and_3_slices.png',
                      size=(1200, 900),
                      reset_camera=False)
Esempio n. 21
0
def connective_streamlines_figuremaker(allstreamlines,
                                       ROI_streamlines,
                                       ROI_names,
                                       anat_path,
                                       threshold=10.,
                                       verbose=False):
    """
    Render several streamline groups over three orthogonal anatomical slices.

    Each entry of ``ROI_streamlines`` selects one group from
    ``allstreamlines``. Every group is drawn as tubes in its own colour
    (the six-colour palette is reused cyclically) in a single scene. The
    tubes are overlaid on the slicer's axial slice plus the middle sagittal
    and coronal slices of the anatomy. The scene is saved to
    ``bundles_and_3_slices.png`` in the current working directory.

    Parameters
    ----------
    allstreamlines : mapping or sequence
        Container indexed by the entries of ``ROI_streamlines``.
    ROI_streamlines : iterable
        Keys/indices into ``allstreamlines``, one per group to draw.
    ROI_names : sequence of str
        Currently unused; kept for API compatibility.
    anat_path : str
        Path to the anatomical NIfTI image. For a 4D image only the first
        volume is shown.
    threshold : float, optional
        QuickBundles distance threshold for the cluster report.
    verbose : bool, optional
        If True, cluster each group and print cluster statistics.
    """
    colors = [
        window.colors.white, window.colors.cadmium_red_deep,
        window.colors.misty_rose, window.colors.slate_grey_dark,
        window.colors.ivory_black, window.colors.chartreuse
    ]
    slicer_opacity = 0.6

    # Build ONE scene for all groups; re-creating it per group would keep
    # only the last group in the final image.
    scene = window.Scene()
    for roi_idx, roi in enumerate(ROI_streamlines):
        roi_streamline = allstreamlines[roi]
        if verbose:
            # Clustering only feeds this report, so skip it otherwise.
            clusters = QuickBundles(threshold=threshold).cluster(
                roi_streamline)
            print("Nb. clusters:", len(clusters))
            print("Cluster sizes:", [len(cluster) for cluster in clusters])
            print("Small clusters:", clusters < 10)
            if len(clusters) > 0:
                print("Streamlines indices of the first cluster:\n",
                      clusters[0].indices)
                print("Centroid of the last cluster:\n",
                      clusters[-1].centroid)
        # Cycle through the palette so more than six groups do not fail.
        roi_color = colors[roi_idx % len(colors)]
        scene.add(actor.streamtube(roi_streamline, roi_color))

    anat_nifti = load_nifti(anat_path)
    # Support both an image-like object and a (data, affine, ...) tuple.
    try:
        data = anat_nifti.data
    except AttributeError:
        data = anat_nifti[0]
    try:
        affine = anat_nifti.affine
    except AttributeError:
        affine = anat_nifti[1]
    shape = np.shape(data)
    if len(shape) == 4:
        data = data[:, :, :, 0]
    image_actor_z = actor.slicer(data, affine)
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    x_midpoint = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1, 0,
                                 shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    y_midpoint = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint, 0,
                                 shape[2] - 1)

    scene.add(image_actor_z)
    scene.add(image_actor_x)
    scene.add(image_actor_y)
    global size
    size = scene.GetSize()
    show_m = window.ShowManager(scene, size=(1200, 900))
    show_m.initialize()

    interactive = False  # set to True to explore the scene in a window
    if interactive:
        # NOTE(review): `win_callback` is not defined in this function; it is
        # expected at module level.
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()
    else:
        window.record(scene,
                      out_path='bundles_and_3_slices.png',
                      size=(1200, 900),
                      reset_camera=False)
Esempio n. 22
0
def plot_bundles_with_metric(bundle_path,
                             endings_path,
                             brain_mask_path,
                             bundle,
                             metrics,
                             output_path,
                             tracking_format="trk_legacy",
                             show_color_bar=True):
    """
    Render a bundle coloured by a metric profile along its length and save it.

    Processing steps:

    - Streamlines are loaded and every second one is kept.
    - Each streamline is flipped so that it starts inside the (dilated)
      ``endings_path`` region.
    - Streamlines are resampled, and each point is assigned a segment index.
    - Each point is coloured with the red palette entry for the (rescaled)
      metric value of its segment.
    - The brain mask is added as a translucent contour.
    - The scene is recorded to ``output_path``.

    Parameters
    ----------
    bundle_path : str
        Tractogram of the bundle. Read with ``trackvis.read`` when
        ``tracking_format == "trk_legacy"``, otherwise with
        ``nib.streamlines.load``.
    endings_path : str
        NIfTI mask of the region where streamlines should begin. Its affine
        is also used for the brain-mask contour.
    brain_mask_path : str
        NIfTI brain mask drawn as a translucent surface.
    bundle : str
        Bundle name; used to choose the camera orientation.
    metrics : sequence of float
        Metric profile along the bundle. Per the comment below it is
        expected to have 98 values (tractometry drops the first and last
        point); it is padded to 100 here.
    output_path : str
        Where the rendered image is written.
    tracking_format : str, optional
        ``"trk_legacy"`` or any format readable by ``nib.streamlines.load``.
    show_color_bar : bool, optional
        Add a scalar bar spanning the original (unscaled) metric range.

    Raises
    ------
    ValueError
        If the orientation returned for ``bundle`` is not sagittal, coronal
        or axial.
    """
    import seaborn as sns  # import in function to avoid error if not installed (this is only needed in this function)
    from dipy.viz import actor, window
    from tractseg.libs import vtk_utils

    def _add_extra_point_to_last_streamline(sl):
        # Coloring broken as soon as all streamlines have same number of points -> why???
        # Add one number to last streamline to make it have a different number
        # Note: modifies the given list in place (and also returns it).
        sl[-1] = np.append(sl[-1], [sl[-1][-1]], axis=0)
        return sl

    # Settings
    NR_SEGMENTS = 100
    ANTI_INTERPOL_MULT = 1  # increase number of points to avoid interpolation to blur the colors
    algorithm = "distance_map"  # equal_dist | distance_map | cutting_plane
    # colors = np.array(sns.color_palette("coolwarm", NR_SEGMENTS))  # colormap blue to red (does not fit to colorbar)
    colors = np.array(sns.light_palette(
        "red", NR_SEGMENTS))  # colormap only red, which fits to color_bar
    img_size = (1000, 1000)

    # Tractometry skips first and last element. Therefore we only have 98 instead of 100 elements.
    # Here we duplicate the first and last element to get back to 100 elements
    metrics = list(metrics)
    metrics = np.array([metrics[0]] + metrics + [metrics[-1]])

    # Keep the raw range for the colour bar before rescaling to palette indices.
    metrics_max = metrics.max()
    metrics_min = metrics.min()
    if metrics_max == metrics_min:
        # Constant profile: every point maps to the first palette colour.
        metrics = np.zeros(len(metrics))
    else:
        metrics = img_utils.scale_to_range(
            metrics,
            range=(0, 99))  # range needs to be same as segments in colormap

    orientation = dataset_specific_utils.get_optimal_orientation_for_bundle(
        bundle)

    # Load mask
    beginnings_img = nib.load(endings_path)
    beginnings = beginnings_img.get_fdata().astype(np.uint8)
    # One round of binary dilation to make the start-region test more tolerant.
    for i in range(1):
        beginnings = binary_dilation(beginnings)

    # Load trackings
    if tracking_format == "trk_legacy":
        streams, hdr = trackvis.read(bundle_path)
        streamlines = [s[0] for s in streams]
    else:
        sl_file = nib.streamlines.load(bundle_path)
        streamlines = sl_file.streamlines

    # Reduce streamline count
    streamlines = streamlines[::2]

    # Reorder to make all streamlines have same start region
    # NOTE(review): the +0.5 shift (undone below with -0.5) presumably makes
    # the int() truncation in the lookup act as rounding to the nearest
    # voxel — confirm against fiber_utils.add_to_each_streamline.
    streamlines = fiber_utils.add_to_each_streamline(streamlines, 0.5)
    streamlines_new = []
    for idx, sl in enumerate(streamlines):
        startpoint = sl[0]
        # Flip streamline if not in right order
        if beginnings[int(startpoint[0]),
                      int(startpoint[1]),
                      int(startpoint[2])] == 0:
            sl = sl[::-1, :]
        streamlines_new.append(sl)
    streamlines = fiber_utils.add_to_each_streamline(streamlines_new, -0.5)

    if algorithm == "distance_map" or algorithm == "equal_dist":
        streamlines = fiber_utils.resample_fibers(
            streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    elif algorithm == "cutting_plane":
        streamlines = fiber_utils.resample_to_same_distance(
            streamlines,
            max_nr_points=NR_SEGMENTS,
            ANTI_INTERPOL_MULT=ANTI_INTERPOL_MULT)

    # Cut start and end by percentage
    # streamlines = FiberUtils.resample_fibers(streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
    # remove = int((NR_SEGMENTS * ANTI_INTERPOL_MULT) * 0.15)  # remove X% in beginning and end
    # streamlines = np.array(streamlines)[:, remove:-remove, :]
    # streamlines = list(streamlines)

    # Assign every streamline point a segment index; it selects the metric
    # value (and hence the colour) of that point further below.
    if algorithm == "equal_dist":
        # Point k of every streamline belongs to segment k.
        segment_idxs = []
        for i in range(len(streamlines)):
            segment_idxs.append(list(range(NR_SEGMENTS * ANTI_INTERPOL_MULT)))
        segment_idxs = np.array(segment_idxs)

    elif algorithm == "distance_map":
        # Segment = index of the nearest point among all centroid points.
        # NOTE(review): indices only stay below NR_SEGMENTS if QuickBundles
        # (threshold=100) returns a single centroid; with more clusters they
        # would index past `metrics` — confirm.
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines)
        centroids = Streamlines(clusters.centroids)
        _, segment_idxs = cKDTree(centroids.data, 1,
                                  copy_data=True).query(streamlines, k=1)

    elif algorithm == "cutting_plane":
        streamlines_resamp = fiber_utils.resample_fibers(
            streamlines, NR_SEGMENTS * ANTI_INTERPOL_MULT)
        metric = AveragePointwiseEuclideanMetric()
        qb = QuickBundles(threshold=100., metric=metric)
        clusters = qb.cluster(streamlines_resamp)
        centroid = Streamlines(clusters.centroids)[0]
        # index of the middle cluster
        middle_idx = int(NR_SEGMENTS / 2) * ANTI_INTERPOL_MULT
        middle_point = centroid[middle_idx]
        segment_idxs = fiber_utils.get_idxs_of_closest_points(
            streamlines, middle_point)
        # Align along the middle and assign indices
        segment_idxs_eqlen = []
        for idx, sl in enumerate(streamlines):
            sl_middle_pos = segment_idxs[idx]
            before_elems = sl_middle_pos
            after_elems = len(sl) - sl_middle_pos
            base_idx = 1000  # use higher index to avoid negative numbers for area below middle
            r = range((base_idx - before_elems), (base_idx + after_elems))
            segment_idxs_eqlen.append(r)
        # NOTE(review): these indices are offset by base_idx (~1000), far
        # beyond the length of `metrics` (100 when 98 values are passed), so
        # the colour lookup below would raise IndexError. This branch is not
        # selected by the hard-coded `algorithm` above.
        segment_idxs = segment_idxs_eqlen

    # Add extra point otherwise coloring BUG
    streamlines = _add_extra_point_to_last_streamline(streamlines)

    renderer = window.Scene()
    colors_all = []  # final shape will be [nr_streamlines, nr_points, 3]
    for jdx, sl in enumerate(streamlines):
        colors_sl = []
        for idx, p in enumerate(sl):
            # The extra point appended above has no segment index of its
            # own, so it reuses the previous point's index.
            if idx >= len(segment_idxs[jdx]):
                seg_idx = segment_idxs[jdx][idx - 1]
            else:
                seg_idx = segment_idxs[jdx][idx]

            m = metrics[int(seg_idx / ANTI_INTERPOL_MULT)]
            color = colors[int(m)]
            colors_sl.append(color)
        colors_all.append(
            colors_sl
        )  # this can not be converted to numpy array because last element has one more elem

    sl_actor = actor.streamtube(streamlines,
                                colors=colors_all,
                                linewidth=0.2,
                                opacity=1)
    renderer.add(sl_actor)

    # plot brain mask
    mask = nib.load(brain_mask_path).get_fdata().astype(np.uint8)
    cont_actor = vtk_utils.contour_from_roi_smooth(
        mask,
        affine=beginnings_img.affine,
        color=[.9, .9, .9],
        opacity=.2,
        smoothing=50)
    renderer.add(cont_actor)

    if show_color_bar:
        # Colour bar uses the raw (pre-scaling) metric range.
        lut_cmap = actor.colormap_lookup_table(scale_range=(metrics_min,
                                                            metrics_max),
                                               hue_range=(0.0, 0.0),
                                               saturation_range=(0.0, 1.0))
        renderer.add(actor.scalar_bar(lut_cmap))

    # Hand-tuned camera positions per orientation; axial keeps the default.
    if orientation == "sagittal":
        renderer.set_camera(position=(-412.95, -34.38, 80.15),
                            focal_point=(102.46, -16.96, -11.71),
                            view_up=(0.1806, 0.0, 0.9835))
    elif orientation == "coronal":
        renderer.set_camera(position=(-48.63, 360.31, 98.37),
                            focal_point=(-20.16, 92.89, 36.02),
                            view_up=(-0.0047, -0.2275, 0.9737))
    elif orientation == "axial":
        pass
    else:
        raise ValueError("Invalid orientation provided")

    # Use this to interatively get new camera angle
    # window.show(renderer, size=img_size, reset_camera=False)
    # print(renderer.get_camera())

    window.record(renderer, out_path=output_path, size=img_size)
Esempio n. 23
0
def launch_quickbundles(streamlines,
                        outpath,
                        ROIname="all",
                        threshold=10.,
                        labelmask=None,
                        affine=None,
                        interactive=False):
    """
    Cluster streamlines with QuickBundles and save three overview images.

    Images written (``outpath`` and ``ROIname`` are concatenated as plain
    strings, so ``outpath`` should end with a separator when it names a
    directory):

    - ``outpath + ROIname + '_initial.png'``: all streamlines, plus three
      orthogonal slices of ``labelmask`` when given.
    - ``..._centroids.png``: faint streamlines with the coloured cluster
      centroids.
    - ``..._clusters.png``: every streamline coloured by its cluster.

    Parameters
    ----------
    streamlines : sequence of streamlines
        Streamlines to cluster and render.
    outpath : str
        Prefix of the output image paths.
    ROIname : str, optional
        Inserted between ``outpath`` and the image suffix. Default: "all".
    threshold : float, optional
        QuickBundles distance threshold. Default: 10.
    labelmask : ndarray, optional
        3D volume shown as three orthogonal slices in the first image.
    affine : ndarray, optional
        Affine for the ``labelmask`` slicer. Default: identity.
    interactive : bool, optional
        Also show each scene in an interactive window. Default: False.
    """
    # Avoid a shared mutable array as default argument.
    if affine is None:
        affine = np.eye(4)

    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)

    print("Nb. clusters:", len(clusters))
    print("Cluster sizes:", [len(cluster) for cluster in clusters])
    print("Small clusters:", clusters < 10)
    if len(clusters) > 0:
        print("Streamlines indices of the first cluster:\n",
              clusters[0].indices)
        print("Centroid of the last cluster:\n", clusters[-1].centroid)

    # 1) All streamlines, optionally over three label-mask slices.
    scene = window.Scene()
    scene.SetBackground(1, 1, 1)
    scene.add(actor.streamtube(streamlines, window.colors.misty_rose))
    if labelmask is not None:
        shape = labelmask.shape
        image_actor_z = actor.slicer(labelmask, affine)
        slicer_opacity = 0.6
        image_actor_z.opacity(slicer_opacity)

        image_actor_x = image_actor_z.copy()
        x_midpoint = int(np.round(shape[0] / 2))
        image_actor_x.display_extent(x_midpoint, x_midpoint, 0, shape[1] - 1,
                                     0, shape[2] - 1)

        image_actor_y = image_actor_z.copy()
        y_midpoint = int(np.round(shape[1] / 2))
        image_actor_y.display_extent(0, shape[0] - 1, y_midpoint, y_midpoint,
                                     0, shape[2] - 1)
        scene.add(image_actor_z)
        scene.add(image_actor_x)
        scene.add(image_actor_y)

    window.record(scene,
                  out_path=outpath + ROIname + '_initial.png',
                  size=(600, 600))
    if interactive:
        window.show(scene)

    # 2) Faint streamlines with one coloured tube per cluster centroid.
    #    (The label-mask slices are only drawn in the first image.)
    colormap = actor.create_colormap(np.arange(len(clusters)))

    scene.clear()
    scene.SetBackground(1, 1, 1)
    scene.add(actor.streamtube(streamlines, window.colors.white, opacity=0.05))
    scene.add(actor.streamtube(clusters.centroids, colormap, linewidth=0.4))
    window.record(scene,
                  out_path=outpath + ROIname + '_centroids.png',
                  size=(600, 600))
    if interactive:
        window.show(scene)

    # 3) Every streamline coloured by the colour of its cluster.
    colormap_full = np.ones((len(streamlines), 3))
    for cluster, color in zip(clusters, colormap):
        colormap_full[cluster.indices] = color

    scene.clear()
    scene.SetBackground(1, 1, 1)
    scene.add(actor.streamtube(streamlines, colormap_full))
    window.record(scene,
                  out_path=outpath + ROIname + '_clusters.png',
                  size=(600, 600))
    if interactive:
        window.show(scene)
Esempio n. 24
0
    #mean_vol_per_tract.append(np.nanmedian(s))

# Lookup-table settings: hue and saturation ranges of the colour map and the
# data range (`scale`) mapped onto it.
hue = [0.25, -0.05]
saturation = [0, 1]
scale = [3, 12]

# Optional subsampling: keep every `downsamp`-th streamline together with its
# value. str1 and mean_vol_per_tract are sliced in lockstep; presumably they
# hold one value per streamline — confirm where they are built.
if downsamp != 1:
    mean_vol_per_tract = mean_vol_per_tract[::downsamp]
    str1 = str1[::downsamp]

# Output path of the rendered image.
mean_pasi_weighted_img = f'{folder_name}{os.sep}streamlines{os.sep}CC_3-12_Exp_DTI_PreReg_1_ds2_tube0p3.png'

lut_cmap = actor.colormap_lookup_table(hue_range=hue,
                                       saturation_range=saturation,
                                       scale_range=scale)

#str2 = transform_streamlines(str1, np.linalg.inv(affine2))
# Tubes coloured through the lookup table by mean_vol_per_tract, plus a
# scalar bar showing the same colour map.
streamlines_actor = actor.streamtube(str1,
                                     mean_vol_per_tract,
                                     linewidth=0.3,
                                     lookup_colormap=lut_cmap)
bar = actor.scalar_bar(lut_cmap)
r = window.Scene()
r.add(streamlines_actor)
r.add(bar)
#r.SetBackground(*window.colors.white)

# Interactive preview before saving.
window.show(r)
# NOTE(review): the intent seems to be keeping the camera chosen in the
# interactive window for the saved image. If camera_info() only prints the
# camera state and returns None (as in fury), this call changes nothing, and
# window.record() may reset the camera anyway unless reset_camera=False is
# passed — confirm.
r.set_camera(r.camera_info())
window.record(r, out_path=mean_pasi_weighted_img, size=(800, 800))
Esempio n. 25
0
def visualize_bundles(sft,
                      affine=None,
                      n_points=None,
                      bundle_dict=None,
                      bundle=None,
                      colors=None,
                      color_by_volume=None,
                      cbv_lims=[None, None],
                      figure=None,
                      background=(1, 1, 1),
                      interact=False,
                      inline=False,
                      flip_axes=None):
    """
    Visualize bundles in 3D using VTK, drawing streamlines as thick lines.

    Parameters
    ----------
    sft : Stateful Tractogram, str
        A Stateful Tractogram containing streamline information, or a path
        to a trk file. To visualize individual bundles, the Stateful
        Tractogram must contain a bundle key in its data_per_streamline,
        holding a list of bundle `'uid'` values.

    affine : ndarray, optional
       An affine transformation to apply to the streamlines before
       visualization. Default: no transform.

    n_points : int or None
        Number of points to resample streamlines to before plotting. If
        None, no resampling is done.

    bundle_dict : dict, optional
        Keys are names of bundles and values are dicts that should include
        a key `'uid'` with values as integers for selection from the sft
        metadata. Default: bundles are either not identified, or identified
        only as unique integers in the metadata.

    bundle : str or int, optional
        The name of a bundle to select from among the keys in `bundle_dict`
        or an integer for selection from the sft metadata.

    colors : dict or list
        If this is a dict, keys are bundle names and values are RGB tuples.
        If this is a list, each item is an RGB tuple. Defaults to a list
        with Tableau 20 RGB values if bundle_dict is None, or dict from
        bundles to Tableau 20 RGB values if bundle_dict is not None.

    color_by_volume : ndarray or str, optional
        3d volume used to shade the bundles. Ignored by this VTK backend;
        shading only works with the plotly backend.
        Default: None

    cbv_lims : ndarray
        Of the form (lower bound, upper bound). Ignored by this VTK backend.
        Default: [None, None]

    background : tuple, optional
        RGB values for the background. Default: (1, 1, 1), which is white
        background.

    figure : fury Scene object, optional
        Scene the bundles are added to. Default: a new Scene is created.

    interact : bool
        Whether to provide an interactive VTK window for interaction.
        Default: False

    inline : bool
        Whether to embed the visualization inline in a notebook. Only works
        in the notebook context. Default: False.

    flip_axes : None
        Only present to conform the fury and plotly APIs; unused.

    Returns
    -------
    Fury Scene object
    """
    scene = window.Scene() if figure is None else figure
    red, green, blue = background[0], background[1], background[2]
    scene.SetBackground(red, green, blue)

    tracts = vut.tract_generator(sft, affine, bundle, bundle_dict, colors,
                                 n_points)
    for streamlines, tract_color, tract_name, _ in tracts:
        streamlines = list(streamlines)
        # The catch-all group gets per-streamline colours from line_colors.
        if tract_name == "all_bundles":
            tract_color = line_colors(streamlines)

        line_actor = actor.line(streamlines, tract_color)
        scene.add(line_actor)
        line_props = line_actor.GetProperty()
        line_props.SetRenderLinesAsTubes(1)
        line_props.SetLineWidth(6)

    return _inline_interact(scene, inline, interact)
Esempio n. 26
0
def visualize_volume(volume,
                     x=None,
                     y=None,
                     z=None,
                     figure=None,
                     flip_axes=None,
                     opacity=0.6,
                     inline=True,
                     interact=False):
    """
    Visualize a volume as three orthogonal slices.

    Parameters
    ----------
    volume : ndarray or str
        3d volume to visualize.

    x, y, z : int, optional
        Indices of the sagittal (x), coronal (y) and axial (z) slices to
        display. x and y default to the middle of the volume. When z is
        None, the slicer's default axial slice is kept.

    figure : fury Scene object, optional
        If provided, the visualization will be added to this Scene. Default:
        Initialize a new Scene.

    flip_axes : None
        This parameter is to conform fury and plotly APIs.

    opacity : float, optional
        Initial opacity of slices.
        Default: 0.6

    interact : bool
        Whether to provide an interactive VTK window with sliders for the
        three slice positions and the opacity.
        Default: False

    inline : bool
        Whether to embed the visualization inline in a notebook. Only works
        in the notebook context. Default: True.

    Returns
    -------
    Fury Scene object
    """
    volume = vut.load_volume(volume)

    if figure is None:
        figure = window.Scene()

    shape = volume.shape
    image_actor_z = actor.slicer(volume)
    slicer_opacity = opacity
    image_actor_z.opacity(slicer_opacity)

    image_actor_x = image_actor_z.copy()
    if x is None:
        x = int(np.round(shape[0] / 2))
    image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1)

    image_actor_y = image_actor_z.copy()
    if y is None:
        y = int(np.round(shape[1] / 2))
    image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1)

    # Honour an explicit z (previously ignored); otherwise keep the slicer's
    # default axial slice.
    if z is not None:
        image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z)

    figure.add(image_actor_z)
    figure.add(image_actor_x)
    figure.add(image_actor_y)

    show_m = window.ShowManager(figure, size=(1200, 900))
    show_m.initialize()

    if interact:
        # Start the sliders at the slices that are actually displayed.
        z_initial = shape[2] / 2 if z is None else z
        line_slider_z = ui.LineSlider2D(min_value=0,
                                        max_value=shape[2] - 1,
                                        initial_value=z_initial,
                                        text_template="{value:.0f}",
                                        length=140)

        line_slider_x = ui.LineSlider2D(min_value=0,
                                        max_value=shape[0] - 1,
                                        initial_value=x,
                                        text_template="{value:.0f}",
                                        length=140)

        line_slider_y = ui.LineSlider2D(min_value=0,
                                        max_value=shape[1] - 1,
                                        initial_value=y,
                                        text_template="{value:.0f}",
                                        length=140)

        opacity_slider = ui.LineSlider2D(min_value=0.0,
                                         max_value=1.0,
                                         initial_value=slicer_opacity,
                                         length=140)

        def change_slice_z(slider):
            z = int(np.round(slider.value))
            image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z,
                                         z)

        def change_slice_x(slider):
            x = int(np.round(slider.value))
            image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0,
                                         shape[2] - 1)

        def change_slice_y(slider):
            y = int(np.round(slider.value))
            image_actor_y.display_extent(0, shape[0] - 1, y, y, 0,
                                         shape[2] - 1)

        def change_opacity(slider):
            slicer_opacity = slider.value
            image_actor_z.opacity(slicer_opacity)
            image_actor_x.opacity(slicer_opacity)
            image_actor_y.opacity(slicer_opacity)

        line_slider_z.on_change = change_slice_z
        line_slider_x.on_change = change_slice_x
        line_slider_y.on_change = change_slice_y
        opacity_slider.on_change = change_opacity

        def build_label(text):
            label = ui.TextBlock2D()
            label.message = text
            label.font_size = 18
            label.font_family = 'Arial'
            label.justification = 'left'
            label.bold = False
            label.italic = False
            label.shadow = False
            label.background = (0, 0, 0)
            label.color = (1, 1, 1)

            return label

        line_slider_label_z = build_label(text="Z Slice")
        line_slider_label_x = build_label(text="X Slice")
        line_slider_label_y = build_label(text="Y Slice")
        opacity_slider_label = build_label(text="Opacity")

        panel = ui.Panel2D(size=(300, 200),
                           color=(1, 1, 1),
                           opacity=0.1,
                           align="right")
        panel.center = (1030, 120)

        panel.add_element(line_slider_label_x, (0.1, 0.75))
        panel.add_element(line_slider_x, (0.38, 0.75))
        panel.add_element(line_slider_label_y, (0.1, 0.55))
        panel.add_element(line_slider_y, (0.38, 0.55))
        panel.add_element(line_slider_label_z, (0.1, 0.35))
        panel.add_element(line_slider_z, (0.38, 0.35))
        panel.add_element(opacity_slider_label, (0.1, 0.15))
        panel.add_element(opacity_slider, (0.38, 0.15))

        show_m.scene.add(panel)

        # Track the window size so the right-aligned panel follows resizes.
        global size
        size = figure.GetSize()

        def win_callback(obj, event):
            global size
            if size != obj.GetSize():
                size_old = size
                size = obj.GetSize()
                size_change = [size[0] - size_old[0], 0]
                panel.re_align(size_change)

    figure.zoom(1.5)
    figure.reset_clipping_range()

    if interact:
        show_m.add_window_callback(win_callback)
        show_m.render()
        show_m.start()

    return _inline_interact(figure, inline, interact)
Esempio n. 27
0
def plot_tracts(classes,
                bundle_segmentations,
                affine,
                out_dir,
                brain_mask=None,
                bundles=None):
    """
    Render selected bundle masks, one per row, into ``out_dir/preview.png``.

    Each bundle mask is drawn with ``plot_mask``. IFO_right is shown
    sagittally and every other bundle axially. Each mask gets its name as a
    text label, and the scene uses a parallel projection.

    Parameters
    ----------
    classes
        Passed to ``dataset_specific_utils.get_bundle_names`` to map bundle
        names to channels of ``bundle_segmentations``.
    bundle_segmentations : ndarray
        4D array whose last axis indexes bundles. Channel i corresponds to
        entry i of the bundle-name list after dropping its first entry.
    affine : ndarray
        Affine passed to ``plot_mask``.
    out_dir : str
        Directory receiving ``preview.png``.
    brain_mask : ndarray, optional
        Passed to ``plot_mask``.
    bundles : list of str, optional
        Bundle names to render, one per row.
        Default: ["CST_right", "CA", "IFO_right"].

    Raises
    ------
    ValueError
        If a requested bundle is not in the bundle-name list for
        ``classes``.

    Notes
    -----
    By default this does not work on a remote server connection (ssh -X) because -X does not support OpenGL.
    On the remote Server you can do 'export DISPLAY=":0"' .
    (you should set the value you get if you do 'echo $DISPLAY' if you
    login locally on the remote server). Then all graphics will get rendered locally and not via -X.
    (important: graphical session needs to be running on remote server (e.g. via login locally))
    (important: login needed, not just stay at login screen)

    If running on a headless server without Display using Xvfb might help:
    https://stackoverflow.com/questions/6281998/can-i-run-glu-opengl-on-a-headless-server
    """
    from dipy.viz import window
    from tractseg.libs import vtk_utils

    SMOOTHING = 10
    WINDOW_SIZE = (800, 800)
    # Added to the row height; bigger values leave more space between rows.
    BORDER_Y = -100
    # Bundles rendered sagittally; all others are rendered axially.
    SAGITTAL_BUNDLES = {"IFO_right"}

    if bundles is None:
        bundles = ["CST_right", "CA", "IFO_right"]

    renderer = window.Scene()
    renderer.projection('parallel')

    rows = len(bundles)
    X, Y = bundle_segmentations.shape[:2]
    # The first entry of the name list has no channel in the segmentation.
    bundle_names = dataset_specific_utils.get_bundle_names(classes)[1:]
    row_height = Y * 2 + BORDER_Y
    x_current = 0  # single column: only one method is rendered

    for j, bundle in enumerate(bundles):
        bundle_idx = bundle_names.index(bundle)
        mask_data = bundle_segmentations[:, :, :, bundle_idx]
        orientation = "sagittal" if bundle in SAGITTAL_BUNDLES else "axial"

        # The first bundle gets the largest vertical offset.
        y_current = rows * row_height - row_height * j

        plot_mask(renderer,
                  mask_data,
                  affine,
                  x_current,
                  y_current,
                  orientation=orientation,
                  smoothing=SMOOTHING,
                  brain_mask=brain_mask)

        # Bundle label, placed to the left of its row.
        text_offset_top = -50
        text_offset_side = -100
        position = (0 - int(X) + text_offset_side, y_current + text_offset_top,
                    50)
        text_actor = vtk_utils.label(text=bundle,
                                     pos=position,
                                     scale=(6, 6, 6),
                                     color=(1, 1, 1))
        renderer.add(text_actor)

    renderer.reset_camera()
    window.record(renderer,
                  out_path=join(out_dir, "preview.png"),
                  size=(WINDOW_SIZE[0], WINDOW_SIZE[1]),
                  reset_camera=False,
                  magnification=2)
Esempio n. 28
0
def visualize_bundles(trk,
                      affine=None,
                      bundle_dict=None,
                      bundle=None,
                      colors=None,
                      scene=None,
                      background=(1, 1, 1),
                      interact=False,
                      inline=False):
    """
    Visualize bundles in 3D using VTK

    Parameters
    ----------
    trk : str, list, or Streamlines
        The streamline information. A string is treated as a file path and
        loaded with ``nib.streamlines.load``; anything else is wrapped in a
        ``nib.streamlines.Tractogram``.

    affine : ndarray, optional
        If provided, the *inverse* of this affine is applied to the
        streamlines before visualization. Default: no transform.

    bundle_dict : dict, optional
        Keys are names of bundles and values are dicts that should include
        a key `'uid'` with values as integers for selection from the trk
        metadata. Default: bundles are either not identified, or identified
        only as unique integers in the metadata.

    bundle : str or int, optional
        The name of a bundle to select from among the keys in `bundle_dict`
        (requires `bundle_dict`) or an integer for selection from the trk
        metadata. Default: show all bundles.

    colors : dict or list, optional
        If this is a dict, keys are bundle names and values are RGB tuples.
        If this is a list, each item is an RGB tuple. When `bundle_dict` is
        None, the integer bundle id is used to cycle through the colors.
        Default: the module-level ``color_dict``.

    scene : fury Scene object, optional
        If provided, the visualization will be added to this Scene. Default:
        Initialize a new Scene.

    background : tuple, optional
        RGB values for the background. Default: (1, 1, 1), which is white
        background.

    interact : bool
        Whether to provide an interactive VTK window for interaction.
        Default: False

    inline : bool
        Whether to embed the visualization inline in a notebook. Only works
        in the notebook context. Default: False.

    Returns
    -------
    Fury Scene object
        As returned by ``_inline_interact(scene, inline, interact)``.

    Raises
    ------
    ValueError
        If `bundle_dict` is given but none of its entries has the uid of a
        bundle that is being drawn.
    """
    if isinstance(trk, str):
        trk = nib.streamlines.load(trk)
        tg = trk.tractogram
    else:
        # Assume these are streamlines (as list or Streamlines object):
        tg = nib.streamlines.Tractogram(trk)

    if affine is not None:
        # Note: the inverse of `affine` is what gets applied.
        tg = tg.apply_affine(np.linalg.inv(affine))

    streamlines = tg.streamlines

    if scene is None:
        scene = window.Scene()

    scene.SetBackground(background[0], background[1], background[2])

    if colors is None:
        # Fall back to the module-level default color mapping:
        colors = color_dict

    def _color_selector(bundle_dict, colors, b):
        """Return the RGB color for the bundle with identifier `b`."""
        if bundle_dict is None:
            # No names available: cycle through the colors by integer id.
            # (dict values must be materialized; dict_values is not indexable)
            if isinstance(colors, list):
                palette = colors
            else:
                palette = list(colors.values())
            return palette[int(b) % len(palette)]
        # We have a mapping from bundle names to UIDs: find the name for `b`.
        b_name = next((name for name, info in bundle_dict.items()
                       if info['uid'] == b), None)
        if b_name is None:
            raise ValueError(f"No bundle in bundle_dict has uid {b!r}")
        return colors[b_name]

    def _add_lines(lines, color):
        """Add `lines` to the scene as one tube-rendered line actor."""
        sl_actor = actor.line(lines, color)
        sl_actor.GetProperty().SetRenderLinesAsTubes(1)
        sl_actor.GetProperty().SetLineWidth(6)
        scene.add(sl_actor)

    if not list(tg.data_per_streamline.keys()):
        # There are no bundles in here: visualize all the streamlines with
        # directionally assigned RGB.
        streamlines = list(streamlines)
        _add_lines(streamlines, line_colors(streamlines))
        return _inline_interact(scene, inline, interact)

    # There are bundles: decide which bundle UIDs to draw.
    bundle_labels = tg.data_per_streamline['bundle']
    if bundle is None:
        # No selection: visualize all of them.
        selected = np.unique(bundle_labels)
    elif isinstance(bundle, str):
        # Selection by name: look up its UID.
        selected = [bundle_dict[bundle]['uid']]
    else:
        # It's already a UID.
        selected = [bundle]

    for uid in selected:
        idx = np.where(bundle_labels == uid)[0]
        _add_lines(list(streamlines[idx]),
                   _color_selector(bundle_dict, colors, uid))

    return _inline_interact(scene, inline, interact)
Esempio n. 29
0
# Evaluate the fitted CSD model's ODF on `small_sphere`; negative values are
# clipped to 0 so the result can be used as a discrete probability mass
# function (PMF) over the sphere's directions.
fod = csd_fit.odf(small_sphere)
pmf = fod.clip(min=0)
# Direction getter that draws tracking directions from the discrete PMF,
# restricted by max_angle=30 (presumably degrees — confirm against the dipy
# API) relative to the incoming direction.
prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                                max_angle=30.,
                                                sphere=small_sphere)
# Track from `seeds` until `stopping_criterion` says stop; step size is 0.5
# (units presumably follow the space defined by `affine` — confirm).
streamline_generator = LocalTracking(prob_dg,
                                     stopping_criterion,
                                     seeds,
                                     affine,
                                     step_size=.5)
# Collect the tracked streamlines into a Streamlines container.
streamlines = Streamlines(streamline_generator)
# Wrap with the reference image and declare the coordinates as RASMM.
# NOTE(review): this assumes `affine` maps voxel indices to RAS+ mm — confirm.
sft = StatefulTractogram(streamlines, hardi_img, Space.RASMM)
save_trk(sft, "tractogram_probabilistic_dg_pmf.trk")

# Optional rendering: write a PNG of the streamlines (colored with
# colormap.line_colors) and, if `interactive` is set, open a window.
if has_fury:
    scene = window.Scene()
    scene.add(actor.line(streamlines, colormap.line_colors(streamlines)))
    window.record(scene,
                  out_path='tractogram_probabilistic_dg_pmf.png',
                  size=(800, 800))
    if interactive:
        window.show(scene)
"""
.. figure:: tractogram_probabilistic_dg_pmf.png
   :align: center

   **Corpus Callosum using probabilistic direction getter from PMF**
"""
"""
One disadvantage of using a discrete PMF to represent possible tracking
directions is that it tends to take up a lot of memory (RAM). The size of the
def dMRI2ODF_DTI(PATH, *, slice_index=87, n_preview=10, out_dir='../data'):
    '''
    Fit SHORE ODFs and a DTI model on one slice of a dMRI volume and render them.

    Loads ``data.nii.gz``, ``nodif_brain_mask.nii.gz``, ``bvals`` and
    ``bvecs`` from `PATH` and keeps only ``[:, slice_index, ...]`` (one index
    along the second array axis). It then:

    * saves grayscale previews of the first `n_preview` volumes of that slice
      as ``<cnt>.png`` in the current working directory;
    * fits a SHORE model and evaluates its ODFs on the 'symmetric362' sphere,
      with negative ODF values set to 0;
    * fits a DTI model and renders FA-colored tensor glyphs to
      ``<out_dir>/tensor.png``;
    * renders the SHORE ODF glyphs to ``<out_dir>/odfs.png``.

    Parameters
    ----------
    PATH : str
        Prefix for the input files. It is concatenated directly with the file
        names, so it must end with a path separator (e.g. ``'subject/'``).
    slice_index : int, optional
        Index along the second array axis of the slice to process.
        Default: 87.
    n_preview : int, optional
        Maximum number of volumes saved as PNG previews. Default: 10.
    out_dir : str, optional
        Directory that receives ``tensor.png`` and ``odfs.png``.
        Default: ``'../data'``.

    Returns
    -------
    None
        Results are only written to image files; the ODFs are not returned.
    '''
    import os

    dMRI_img = nib.load(PATH + 'data.nii.gz')
    dMRI_data = dMRI_img.get_fdata()
    mask = nib.load(PATH + 'nodif_brain_mask.nii.gz').get_fdata()

    # Keep a single 2-D slice (plus the volume axis for the data).
    dMRI_data = dMRI_data[:, slice_index, ...]
    mask = mask[:, slice_index, ...]

    # Grayscale previews of the first volumes. A fresh figure per image
    # (closed afterwards) avoids stacking images on one shared figure.
    for cnt in range(min(n_preview, dMRI_data.shape[-1])):
        fig, ax = plt.subplots()
        ax.imshow(dMRI_data[:, :, cnt].transpose(1, 0),
                  cmap='Greys',
                  interpolation='nearest')
        ax.axis('off')
        fig.savefig(str(cnt) + '.png',
                    bbox_inches='tight',
                    dpi=300,
                    transparent=True,
                    pad_inches=0)
        plt.close(fig)

    gtab = gradient_table(bvals=PATH + "bvals", bvecs=PATH + "bvecs")

    # SHORE model hyper-parameters.
    radial_order = 6
    zeta = 700
    lambdaN = 1e-8
    lambdaL = 1e-8

    asm = ShoreModel(gtab,
                     radial_order=radial_order,
                     zeta=zeta,
                     lambdaN=lambdaN,
                     lambdaL=lambdaL)
    asmfit = asm.fit(dMRI_data, mask=mask)
    sphere = get_sphere('symmetric362')
    dMRI_odf = asmfit.odf(sphere)
    dMRI_odf[dMRI_odf <= 0] = 0  # drop negative ODF values

    tenfit = dti.TensorModel(gtab).fit(dMRI_data, mask)

    FA = fractional_anisotropy(tenfit.evals)
    FA[np.isnan(FA)] = 0
    FA = np.clip(FA, 0, 1)
    RGB = color_fa(FA, tenfit.evecs)

    evals = tenfit.evals + 1e-20
    evecs = tenfit.evecs
    # The tiny offset keeps cfa.max() strictly positive (assuming RGB is
    # non-negative), so the normalization cannot divide by zero.
    cfa = RGB + 1e-20
    cfa /= cfa.max()

    # Insert a singleton third spatial axis so the 2-D slice fits the
    # slicer actors, which take volumes.
    evals = np.expand_dims(evals, 2)
    evecs = np.expand_dims(evecs, 2)
    cfa = np.expand_dims(cfa, 2)

    ren = window.Scene()
    ren.add(
        actor.tensor_slicer(evals,
                            evecs,
                            scalar_colors=cfa,
                            sphere=sphere,
                            scale=0.5))
    window.record(ren,
                  n_frames=1,
                  out_path=os.path.join(out_dir, 'tensor.png'),
                  size=(5000, 5000))

    ren = window.Scene()
    ren.add(
        actor.odf_slicer(np.expand_dims(dMRI_odf, 2),
                         sphere=sphere,
                         colormap="plasma",
                         scale=0.5))
    window.record(ren,
                  n_frames=1,
                  out_path=os.path.join(out_dir, 'odfs.png'),
                  size=(5000, 5000))

    return None