Example #1
0
def generate_points():
    """Build an actor that renders three colored points.

    The points sit at (1, 0, 0), (0, 1, 0) and (0, 0, 1) and are colored
    red, green and blue respectively. Each point is stored as its own
    single-point vertex cell.

    Returns
    -------
    point_actor : Actor
        Actor whose mapper draws the three colored vertices.
    """
    centers = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

    # One RGB row per center, scaled to the 0-255 range.
    colors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) * 255

    vtk_vertices = Points()
    # Topology: one vertex cell holding a single point id per center.
    vtk_faces = CellArray()
    for center in centers:
        point_id = vtk_vertices.InsertNextPoint(center)
        vtk_faces.InsertNextCell(1)
        vtk_faces.InsertCellPoint(point_id)

    # Geometry (points) and topology (vertex cells) of the polydata.
    polydata = PolyData()
    polydata.SetPoints(vtk_vertices)
    polydata.SetVerts(vtk_faces)

    set_polydata_colors(polydata, colors)

    mapper = PolyDataMapper()
    mapper.SetInputData(polydata)
    # NOTE(review): False (== 0) is passed as the VBO shift/scale method;
    # presumably this disables VTK's shift/scale of vertex coordinates —
    # confirm against the VTK version in use.
    mapper.SetVBOShiftScaleMethod(False)

    point_actor = Actor()
    point_actor.SetMapper(mapper)

    return point_actor
Example #2
0
def test_shader_callback():
    """Shader callbacks run by priority and can be removed by observer id."""
    cone = ConeSource()
    cone_mapper = PolyDataMapper()
    cone_mapper.SetInputConnection(cone.GetOutputPort())
    actor = Actor()
    actor.SetMapper(cone_mapper)

    # Every callback records a marker here whenever it receives a program.
    test_values = []

    def on_low(_caller, _event, calldata=None):
        if calldata is not None:
            test_values.append(0)

    def on_high(_caller, _event, calldata=None):
        if calldata is not None:
            test_values.append(999)

    def on_mean(_caller, _event, calldata=None):
        if calldata is not None:
            test_values.append(500)

    low_id = fs.add_shader_callback(actor, on_low, 0)

    # A non-numeric priority must be rejected.
    with pytest.raises(Exception):
        fs.add_shader_callback(actor, on_low, priority='str')

    mapper = actor.GetMapper()
    mapper.RemoveObserver(low_id)

    scene = window.Scene()
    scene.add(actor)

    # With its only observer removed, rendering must not trigger anything.
    window.snapshot(scene, size=(200, 200))
    assert len(test_values) == 0

    test_values.clear()

    fs.add_shader_callback(actor, on_high, 999)
    fs.add_shader_callback(actor, on_low, 0)
    mean_id = fs.add_shader_callback(actor, on_mean, 500)

    # Higher priority callbacks must fire first.
    window.snapshot(scene, size=(200, 200))
    assert (test_values[0], test_values[1], test_values[2]) == (999, 500, 0)

    # Removing the mean observer leaves only high and low, in that order.
    mapper.RemoveObserver(mean_id)
    test_values.clear()

    window.snapshot(scene, size=(200, 200))
    assert (test_values[0], test_values[1]) == (999, 0)
Example #3
0
    def __init__(self,
                 odfs,
                 vertices,
                 faces,
                 indices,
                 scale,
                 norm,
                 radial_scale,
                 shape,
                 global_cm,
                 colormap,
                 opacity,
                 affine=None,
                 B=None):
        """Store the ODF field and display the middle slice along Z.

        When `B` is None, `odfs` are treated as SF amplitudes and are
        normalised (if `norm`) and scaled once, here. Rows whose maximum
        absolute amplitude is 0 stay at 0 instead of becoming NaN. The
        caller's `odfs` array is not modified. When `B` is given, `odfs`
        are SH coefficients and normalisation/scaling happen after the
        SH -> SF projection.
        """
        self.vertices = vertices
        self.faces = faces
        self.odfs = odfs
        self.indices = indices
        self.B = B
        self.radial_scale = radial_scale
        self.colormap = colormap
        self.grid_shape = shape
        self.global_cm = global_cm

        # declare a mask to be instantiated in slice_along_axis
        self.mask = None

        # Keep norm/scale whatever the input basis. With SH input they are
        # applied after SH -> SF conversion. `scale` is also needed to draw
        # unscaled spheres when `radial_scale` is False, including for SF
        # input (previously an AttributeError).
        self.norm = norm
        self.scale = scale

        if self.B is None:
            # SF input: normalise and scale only once, here. Build a new
            # array instead of dividing in place so the caller's `odfs`
            # is left untouched (and integer input does not fail).
            sf = np.asarray(self.odfs)
            if norm:
                max_abs = np.abs(sf).max(axis=-1, keepdims=True)
                # avoid 0 / 0 for empty ODFs; they simply stay at zero
                max_abs[max_abs == 0] = 1
                sf = sf / max_abs
            self.odfs = sf * scale

        # Compute world coordinates if an affine is supplied
        self.affine = affine
        if self.affine is not None:
            self.w_verts = self.vertices.dot(affine[:3, :3])
            self.w_pos = apply_affine(affine, np.asarray(self.indices).T)

        # Initialize mapper and slice to the
        # middle of the volume along Z axis
        self.mapper = PolyDataMapper()
        self.SetMapper(self.mapper)
        self.slice_along_axis(self.grid_shape[-1] // 2)
        self.set_opacity(opacity)
Example #4
0
def get_polymapper_from_polydata(polydata):
    """Wrap a vtkPolyData in a configured vtkPolyDataMapper.

    The returned mapper has scalar visibility on, interpolates scalars
    before mapping, has been updated once, and has ``StaticOn()`` set.

    Parameters
    ----------
    polydata : vtkPolyData
        Geometry connected to the mapper through ``set_input``.

    Returns
    -------
    poly_mapper : vtkPolyDataMapper
        The configured mapper.

    """
    mapper = set_input(PolyDataMapper(), polydata)
    # Configuration steps, applied in this exact order.
    for configure in (mapper.ScalarVisibilityOn,
                      mapper.InterpolateScalarsBeforeMappingOn,
                      mapper.Update,
                      mapper.StaticOn):
        configure()
    return mapper
Example #5
0
def repeat_sources(centers,
                   colors,
                   active_scalars=1.,
                   directions=None,
                   source=None,
                   vertices=None,
                   faces=None,
                   orientation=None):
    """Transform a vtksource to glyph.

    One copy of the glyph geometry is placed at every center.

    Parameters
    ----------
    centers : array-like
        Glyph positions, one glyph per row.
    colors : array-like
        One color per center, or a single 1-D color that is repeated for
        every center. Values are multiplied by 255 before conversion.
    active_scalars : scalar or array-like, optional
        Per-glyph scale factors (the glyph filter scales by scalar). A
        single number (Python or NumPy scalar) is repeated for every center.
        A list/tuple is converted to an array. Default is 1.
    directions : array-like, optional
        Per-center vectors used to orient the glyphs.
    source : vtk algorithm, optional
        Geometry source to replicate; used when `faces` is None.
    vertices, faces : array-like, optional
        Custom glyph geometry; used instead of `source` when `faces` is
        given.
    orientation : array-like, optional
        Matrix applied to `source` before glyphing (only when `faces` is
        None).

    Returns
    -------
    actor : Actor
        Actor rendering the glyphs, colored by the 'colors' point array.

    Raises
    ------
    IOError
        If neither `source` nor `faces` is given.
    """
    if source is None and faces is None:
        raise IOError("A source or faces should be defined")

    if np.array(colors).ndim == 1:
        colors = np.tile(colors, (len(centers), 1))

    pts = numpy_to_vtk_points(np.ascontiguousarray(centers))
    cols = numpy_to_vtk_colors(255 * np.ascontiguousarray(colors))
    cols.SetName('colors')

    # Normalise `active_scalars` to an ndarray when possible. Previously,
    # NumPy scalars such as np.float32 and plain lists were silently
    # ignored (no scaling applied).
    if isinstance(active_scalars, (float, int, np.integer, np.floating)):
        active_scalars = np.tile(active_scalars, (len(centers), 1))
    elif isinstance(active_scalars, (list, tuple)):
        active_scalars = np.asarray(active_scalars)
    has_scalars = isinstance(active_scalars, np.ndarray)

    if has_scalars:
        ascalars = numpy_support.numpy_to_vtk(np.asarray(active_scalars),
                                              deep=True,
                                              array_type=VTK_DOUBLE)
        ascalars.SetName('active_scalars')

    if directions is not None:
        directions_fa = numpy_support.numpy_to_vtk(np.asarray(directions),
                                                   deep=True,
                                                   array_type=VTK_DOUBLE)
        directions_fa.SetName('directions')

    polydata_centers = PolyData()
    polydata_geom = PolyData()

    if faces is not None:
        set_polydata_vertices(polydata_geom, vertices)
        set_polydata_triangles(polydata_geom, faces)

    polydata_centers.SetPoints(pts)
    point_data = polydata_centers.GetPointData()
    point_data.AddArray(cols)
    if directions is not None:
        point_data.AddArray(directions_fa)
        point_data.SetActiveVectors('directions')
    if has_scalars:
        point_data.AddArray(ascalars)
        point_data.SetActiveScalars('active_scalars')

    glyph = Glyph3D()
    if faces is None:
        if orientation is not None:
            # Pre-rotate the source geometry with the given matrix.
            transform = Transform()
            transform.SetMatrix(numpy_to_vtk_matrix(orientation))
            rtrans = TransformPolyDataFilter()
            rtrans.SetInputConnection(source.GetOutputPort())
            rtrans.SetTransform(transform)
            source = rtrans
        glyph.SetSourceConnection(source.GetOutputPort())
    else:
        glyph.SetSourceData(polydata_geom)
    glyph.SetInputData(polydata_centers)
    glyph.SetOrient(True)
    glyph.SetScaleModeToScaleByScalar()
    glyph.SetVectorModeToUseVector()
    glyph.Update()

    mapper = PolyDataMapper()
    mapper.SetInputData(glyph.GetOutput())
    # Color from the named 'colors' field rather than the active scalars,
    # which are used for scaling.
    mapper.SetScalarModeToUsePointFieldData()
    mapper.SelectColorArray('colors')

    actor = Actor()
    actor.SetMapper(mapper)
    return actor
Example #6
0
    def __init__(self,
                 directions,
                 indices,
                 values=None,
                 affine=None,
                 colors=None,
                 lookup_colormap=None,
                 linewidth=1,
                 symmetric=True):
        """Build one line segment per non-null peak and attach peak shaders.

        Parameters
        ----------
        directions : ndarray
            Peak directions, indexed as directions[x, y, z][peak] -> 3-vector
            (presumably shape (X, Y, Z, D, 3) — confirm with callers).
        indices : tuple
            (x_indices, y_indices, z_indices) of the voxels to draw.
        values : ndarray, optional
            Peak amplitudes, indexed as values[x, y, z][peak]; 1 when None.
        affine : ndarray, optional
            Transform passed to apply_affine to get world voxel positions.
        colors : optional
            Forwarded to _peaks_colors_from_points.
        lookup_colormap : optional
            Lookup table used when the colors are scalars; defaults to
            colormap_lookup_table().
        linewidth : float, optional
            Line width set on the actor property.
        symmetric : bool, optional
            If True each line spans center - d*v to center + d*v, otherwise
            center to center + d*v.
        """
        if affine is not None:
            # world-space position of every selected voxel, one row each
            w_pos = apply_affine(affine, np.asarray(indices).T)

        valid_dirs = directions[indices]

        # total number of peaks having at least one non-zero component
        num_dirs = len(np.nonzero(np.abs(valid_dirs).max(axis=-1) > 0)[0])

        pnts_per_line = 2

        # per-point buffers: position, originating voxel and line vector
        points_array = np.empty((num_dirs * pnts_per_line, 3))
        centers_array = np.empty_like(points_array, dtype=int)
        diffs_array = np.empty_like(points_array)
        line_count = 0
        for idx, center in enumerate(zip(indices[0], indices[1], indices[2])):
            if affine is None:
                xyz = np.asarray(center)
            else:
                xyz = w_pos[idx, :]
            # peaks of this voxel with at least one non-zero component
            valid_peaks = np.nonzero(
                np.abs(valid_dirs[idx, :, :]).max(axis=-1) > 0.)[0]
            for direction in valid_peaks:
                if values is not None:
                    pv = values[center][direction]
                else:
                    pv = 1.

                if symmetric:
                    point_i = directions[center][direction] * pv + xyz
                    point_e = -directions[center][direction] * pv + xyz
                else:
                    point_i = directions[center][direction] * pv + xyz
                    point_e = xyz

                # each line stores its end point first, then its tip; both
                # points carry the same center and diff for the shaders
                diff = point_e - point_i
                points_array[line_count * pnts_per_line, :] = point_e
                points_array[line_count * pnts_per_line + 1, :] = point_i
                centers_array[line_count * pnts_per_line, :] = center
                centers_array[line_count * pnts_per_line + 1, :] = center
                diffs_array[line_count * pnts_per_line, :] = diff
                diffs_array[line_count * pnts_per_line + 1, :] = diff
                line_count += 1

        vtk_points = numpy_to_vtk_points(points_array)

        vtk_cells = _points_to_vtk_cells(points_array)

        colors_tuple = _peaks_colors_from_points(points_array, colors=colors)
        vtk_colors, colors_are_scalars, self.__global_opacity = colors_tuple

        poly_data = PolyData()
        poly_data.SetPoints(vtk_points)
        poly_data.SetLines(vtk_cells)
        poly_data.GetPointData().SetScalars(vtk_colors)

        self.__mapper = PolyDataMapper()
        self.__mapper.SetInputData(poly_data)
        self.__mapper.ScalarVisibilityOn()
        self.__mapper.SetScalarModeToUsePointFieldData()
        self.__mapper.SelectColorArray('colors')
        self.__mapper.Update()

        self.SetMapper(self.__mapper)

        # per-vertex attributes consumed by the peak shaders
        attribute_to_actor(self, centers_array, 'center')
        attribute_to_actor(self, diffs_array, 'diff')

        vs_dec_code = import_fury_shader('peak_dec.vert')
        vs_impl_code = import_fury_shader('peak_impl.vert')
        fs_dec_code = import_fury_shader('peak_dec.frag')
        fs_impl_code = import_fury_shader('peak_impl.frag')

        shader_to_actor(self,
                        'vertex',
                        decl_code=vs_dec_code,
                        impl_code=vs_impl_code)
        shader_to_actor(self, 'fragment', decl_code=fs_dec_code)
        shader_to_actor(self,
                        'fragment',
                        impl_code=fs_impl_code,
                        block='light')

        # Color scale with a lookup table
        if colors_are_scalars:
            if lookup_colormap is None:
                lookup_colormap = colormap_lookup_table()

            self.__mapper.SetLookupTable(lookup_colormap)
            self.__mapper.UseLookupTableScalarRangeOn()
            self.__mapper.Update()

        self.__lw = linewidth
        self.GetProperty().SetLineWidth(self.__lw)

        # a negative global opacity leaves the property opacity untouched
        if self.__global_opacity >= 0:
            self.GetProperty().SetOpacity(self.__global_opacity)

        # per-axis bounds of the drawn voxel indices
        self.__min_centers = np.min(indices, axis=1)
        self.__max_centers = np.max(indices, axis=1)

        # start in "range" mode showing every drawn voxel
        self.__is_range = True
        self.__low_ranges = self.__min_centers
        self.__high_ranges = self.__max_centers
        # NOTE(review): this is half the max index per axis, not the midpoint
        # of [low, high]; identical only when indices start at 0 — confirm.
        self.__cross_section = self.__high_ranges // 2

        # push the range/cross-section uniforms whenever shaders update
        self.__mapper.AddObserver(Command.UpdateShaderEvent,
                                  self.__display_peaks_vtk_callback)
Example #7
0
class PeakActor(Actor):
    """VTK actor drawing peak directions as line segments, one per peak.

    Parameters
    ----------
    directions : ndarray
        Peak directions. The shape of the array should be (X, Y, Z, D, 3).
    indices : tuple
        Indices given in tuple(x_indices, y_indices, z_indices)
        format for mapping 2D ODF array to 3D voxel grid.
    values : ndarray, optional
        Peak values. The shape of the array should be (X, Y, Z, D).
    affine : array, optional
        4x4 transformation array from native coordinates to world coordinates.
    colors : None or string ('rgb_standard') or tuple (3D or 4D) or
             array/ndarray (N, 3 or 4) or array/ndarray (K, 3 or 4) or
             array/ndarray(N, ) or array/ndarray (K, )
        None selects a standard orientation colormap for every line.
        A single color tuple gives every line the same color.
        An (N, 3 or 4) array, N being the number of points, gives each
        point its own RGB(A) color.
        A (K, 3 or 4) array, K being the number of lines, gives each line
        its own RGB(A) color.
        An (N, ) array, N being the number of points, is used as scalar
        values for the colormap.
        A (K, ) array, K being the number of lines, is used as scalar values
        for the colormap.
    lookup_colormap : vtkLookupTable, optional
        Lookup table used for scalar colors. Default is None which calls
        :func:`fury.actor.colormap_lookup_table`.
    linewidth : float, optional
        Line thickness. Default is 1.
    symmetric: bool, optional
        If True, each peak is drawn along both peaks_dirs and -peaks_dirs.
        Otherwise only the peaks_dirs half is drawn. Default is True.

    """
    def __init__(self,
                 directions,
                 indices,
                 values=None,
                 affine=None,
                 colors=None,
                 lookup_colormap=None,
                 linewidth=1,
                 symmetric=True):
        # World-space voxel positions are only needed with an affine.
        world_pos = (None if affine is None
                     else apply_affine(affine, np.asarray(indices).T))

        selected_dirs = directions[indices]
        # A peak is drawn when at least one of its components is non-zero.
        peak_present = np.abs(selected_dirs).max(axis=-1) > 0
        n_lines = len(np.nonzero(peak_present)[0])

        pts_per_line = 2

        points_array = np.empty((n_lines * pts_per_line, 3))
        centers_array = np.empty_like(points_array, dtype=int)
        diffs_array = np.empty_like(points_array)

        row = 0
        voxels = zip(indices[0], indices[1], indices[2])
        for vox, center in enumerate(voxels):
            origin = np.asarray(center) if affine is None \
                else world_pos[vox, :]
            for peak in np.nonzero(peak_present[vox])[0]:
                amplitude = 1. if values is None else values[center][peak]
                peak_vec = directions[center][peak] * amplitude
                tip = peak_vec + origin
                tail = -peak_vec + origin if symmetric else origin

                # Each segment stores its tail first, then its tip; both
                # points share the voxel center and the tail - tip vector.
                segment = slice(row, row + pts_per_line)
                points_array[row] = tail
                points_array[row + 1] = tip
                centers_array[segment] = center
                diffs_array[segment] = tail - tip
                row += pts_per_line

        vtk_points = numpy_to_vtk_points(points_array)
        vtk_cells = _points_to_vtk_cells(points_array)

        vtk_colors, colors_are_scalars, self.__global_opacity = \
            _peaks_colors_from_points(points_array, colors=colors)

        poly_data = PolyData()
        poly_data.SetPoints(vtk_points)
        poly_data.SetLines(vtk_cells)
        poly_data.GetPointData().SetScalars(vtk_colors)

        self.__mapper = PolyDataMapper()
        self.__mapper.SetInputData(poly_data)
        self.__mapper.ScalarVisibilityOn()
        self.__mapper.SetScalarModeToUsePointFieldData()
        self.__mapper.SelectColorArray('colors')
        self.__mapper.Update()

        self.SetMapper(self.__mapper)

        # Per-vertex attributes read by the peak shaders.
        attribute_to_actor(self, centers_array, 'center')
        attribute_to_actor(self, diffs_array, 'diff')

        vert_decl = import_fury_shader('peak_dec.vert')
        vert_impl = import_fury_shader('peak_impl.vert')
        frag_decl = import_fury_shader('peak_dec.frag')
        frag_impl = import_fury_shader('peak_impl.frag')

        shader_to_actor(self, 'vertex', decl_code=vert_decl,
                        impl_code=vert_impl)
        shader_to_actor(self, 'fragment', decl_code=frag_decl)
        shader_to_actor(self, 'fragment', impl_code=frag_impl,
                        block='light')

        # Scalar colors go through a lookup table.
        if colors_are_scalars:
            if lookup_colormap is None:
                lookup_colormap = colormap_lookup_table()
            self.__mapper.SetLookupTable(lookup_colormap)
            self.__mapper.UseLookupTableScalarRangeOn()
            self.__mapper.Update()

        self.__lw = linewidth
        self.GetProperty().SetLineWidth(self.__lw)

        # A negative global opacity means "leave the property as is".
        if self.__global_opacity >= 0:
            self.GetProperty().SetOpacity(self.__global_opacity)

        self.__min_centers = np.min(indices, axis=1)
        self.__max_centers = np.max(indices, axis=1)

        # Start in range mode covering every drawn voxel.
        self.__is_range = True
        self.__low_ranges = self.__min_centers
        self.__high_ranges = self.__max_centers
        self.__cross_section = self.__high_ranges // 2

        self.__mapper.AddObserver(Command.UpdateShaderEvent,
                                  self.__display_peaks_vtk_callback)

    @calldata_type(VTK_OBJECT)
    def __display_peaks_vtk_callback(self, caller, event, calldata=None):
        """Upload the current range / cross-section uniforms."""
        if calldata is None:
            return
        calldata.SetUniformi('isRange', self.__is_range)
        calldata.SetUniform3f('highRanges', self.__high_ranges)
        calldata.SetUniform3f('lowRanges', self.__low_ranges)
        calldata.SetUniform3f('crossSection', self.__cross_section)

    def display_cross_section(self, x, y, z):
        """Switch to cross-section mode centered on voxel (x, y, z)."""
        self.__is_range = False
        self.__cross_section = [x, y, z]

    def display_extent(self, x1, x2, y1, y2, z1, z2):
        """Switch to range mode showing [x1, x2] x [y1, y2] x [z1, z2]."""
        self.__is_range = True
        self.__low_ranges = [x1, y1, z1]
        self.__high_ranges = [x2, y2, z2]

    @property
    def cross_section(self):
        return self.__cross_section

    @property
    def global_opacity(self):
        return self.__global_opacity

    @global_opacity.setter
    def global_opacity(self, opacity):
        self.__global_opacity = opacity
        self.GetProperty().SetOpacity(self.__global_opacity)

    @property
    def high_ranges(self):
        return self.__high_ranges

    @property
    def is_range(self):
        return self.__is_range

    @property
    def low_ranges(self):
        return self.__low_ranges

    @property
    def linewidth(self):
        return self.__lw

    @linewidth.setter
    def linewidth(self, linewidth):
        self.__lw = linewidth
        self.GetProperty().SetLineWidth(self.__lw)

    @property
    def max_centers(self):
        return self.__max_centers

    @property
    def min_centers(self):
        return self.__min_centers
Example #8
0
class OdfSlicerActor(Actor):
    """
    VTK actor for visualizing slices of ODF field.

    Parameters
    ----------
    odfs : ndarray
        SF or SH coefficients 2-dimensional array.
    vertices: ndarray
        The sphere vertices used for SH to SF projection.
    faces: ndarray
        Indices of sphere vertices forming triangles. Should be
        ordered clockwise (see fury.utils.fix_winding_order).
    indices: tuple
        Indices given in tuple(x_indices, y_indices, z_indices)
        format for mapping 2D ODF array to 3D voxel grid.
    scale : float
        Multiplicative factor to apply to ODF amplitudes.
    norm : bool
        Normalize SF amplitudes so that the maximum
        ODF amplitude per voxel along a direction is 1. ODFs whose
        amplitudes are all zero are left at zero.
    radial_scale : bool
        Scale sphere points by ODF values.
    shape : tuple
        Shape of the voxel grid; used to build the display mask and to
        choose the default (middle) slice.
    global_cm : bool
        If True the colormap will be applied in all ODFs. If False
        it will be applied individually at each voxel.
    colormap : None or str
        The name of the colormap to use. Matplotlib colormaps are supported
        (e.g., 'inferno'). If None then a RGB colormap is used.
    opacity : float
        Takes values from 0 (fully transparent) to 1 (opaque).
    affine : array
        optional 4x4 transformation array from native
        coordinates to world coordinates.
    B : ndarray (n_coeffs, n_vertices)
        Optional SH to SF matrix for projecting `odfs` given in SH
        coefficients on the `sphere`. If None, then the input is assumed
        to be expressed in SF coefficients.
    """
    def __init__(self,
                 odfs,
                 vertices,
                 faces,
                 indices,
                 scale,
                 norm,
                 radial_scale,
                 shape,
                 global_cm,
                 colormap,
                 opacity,
                 affine=None,
                 B=None):
        self.vertices = vertices
        self.faces = faces
        self.odfs = odfs
        self.indices = indices
        self.B = B
        self.radial_scale = radial_scale
        self.colormap = colormap
        self.grid_shape = shape
        self.global_cm = global_cm

        # declare a mask to be instantiated in slice_along_axis
        self.mask = None

        # Keep norm/scale whatever the input basis: with SH input they are
        # applied after SH -> SF conversion (see `_get_sf`), and `scale` is
        # also read by `_get_all_vertices` when `radial_scale` is False,
        # including for SF input (previously an AttributeError).
        self.norm = norm
        self.scale = scale

        if self.B is None:
            # SF input: normalise and scale only once, here. A new array is
            # built so the caller's `odfs` is not modified in place.
            sf = np.asarray(self.odfs)
            if norm:
                sf = self._max_normalize(sf)
            self.odfs = sf * scale

        # Compute world coordinates if an affine is supplied
        self.affine = affine
        if self.affine is not None:
            self.w_verts = self.vertices.dot(affine[:3, :3])
            self.w_pos = apply_affine(affine, np.asarray(self.indices).T)

        # Initialize mapper and slice to the
        # middle of the volume along Z axis
        self.mapper = PolyDataMapper()
        self.SetMapper(self.mapper)
        self.slice_along_axis(self.grid_shape[-1] // 2)
        self.set_opacity(opacity)

    @staticmethod
    def _max_normalize(values):
        """
        Return `values` divided by the max absolute value of each row.

        Rows whose max absolute value is 0 are returned as zeros instead
        of NaNs (0 / 0).
        """
        max_abs = np.abs(values).max(axis=-1, keepdims=True)
        max_abs[max_abs == 0] = 1
        return values / max_abs

    def set_opacity(self, opacity):
        """
        Set opacity value of ODFs to display.
        """
        self.GetProperty().SetOpacity(opacity)

    def display_extent(self, x1, x2, y1, y2, z1, z2):
        """
        Set visible volume from x1 (inclusive) to x2 (inclusive),
        y1 (inclusive) to y2 (inclusive), z1 (inclusive) to z2
        (inclusive).
        """
        mask = np.zeros(self.grid_shape, dtype=bool)
        mask[x1:x2 + 1, y1:y2 + 1, z1:z2 + 1] = True
        self.mask = mask

        self._update_mapper()

    def slice_along_axis(self, slice_index, axis='zaxis'):
        """
        Slice ODF field at given `slice_index` along axis
        in ['xaxis', 'yaxis', 'zaxis'].

        Raises ValueError for any other axis name.
        """
        if axis == 'xaxis':
            self.display_extent(slice_index, slice_index, 0,
                                self.grid_shape[1] - 1, 0,
                                self.grid_shape[2] - 1)
        elif axis == 'yaxis':
            self.display_extent(0, self.grid_shape[0] - 1, slice_index,
                                slice_index, 0, self.grid_shape[2] - 1)
        elif axis == 'zaxis':
            self.display_extent(0, self.grid_shape[0] - 1, 0,
                                self.grid_shape[1] - 1, slice_index,
                                slice_index)
        else:
            raise ValueError('Invalid axis name {0}.'.format(axis))

    def display(self, x=None, y=None, z=None):
        """
        Display a slice along x, y, or z axis.

        The first of `x`, `y`, `z` that is not None selects the slice;
        when all are None the middle slice along Z is shown.
        """
        if x is None and y is None and z is None:
            self.slice_along_axis(self.grid_shape[2] // 2)
        elif x is not None:
            self.slice_along_axis(x, 'xaxis')
        elif y is not None:
            self.slice_along_axis(y, 'yaxis')
        elif z is not None:
            self.slice_along_axis(z, 'zaxis')

    def update_sphere(self, vertices, faces, B):
        """
        Dynamically change the sphere used for SH to SF projection.

        Raises ValueError when the actor was built from SF coefficients
        (no `B` matrix).
        """
        if self.B is None:
            raise ValueError('Can\'t update sphere when using '
                             'SF coefficients.')
        self.vertices = vertices
        if self.affine is not None:
            self.w_verts = self.vertices.dot(self.affine[:3, :3])
        self.faces = faces
        self.B = B

        # draw ODFs with new sphere
        self._update_mapper()

    def _update_mapper(self):
        """
        Map vtkPolyData to the actor.

        An empty polydata is used when no ODF lies inside the mask.
        """
        polydata = PolyData()

        offsets = self._get_odf_offsets(self.mask)
        if len(offsets) == 0:
            self.mapper.SetInputData(polydata)
            return None

        sph_dirs = self._get_sphere_directions()
        sf = self._get_sf(self.mask)

        all_vertices = self._get_all_vertices(offsets, sph_dirs, sf)
        all_faces = self._get_all_faces(len(offsets), len(sph_dirs))
        all_colors = self._generate_color_for_vertices(sf)

        # TODO: There is a lot of deep copy here.
        # Optimize (see viz_network.py example).
        set_polydata_triangles(polydata, all_faces)
        set_polydata_vertices(polydata, all_vertices)
        set_polydata_colors(polydata, all_colors)

        self.mapper.SetInputData(polydata)

    def _get_odf_offsets(self, mask):
        """
        Get the position of non-zero voxels inside `mask`.
        """
        if self.affine is not None:
            return self.w_pos[mask[self.indices]]
        return np.asarray(self.indices).T[mask[self.indices]]

    def _get_sphere_directions(self):
        """
        Get the sphere directions onto which is projected the signal.
        """
        if self.affine is not None:
            return self.w_verts
        return self.vertices

    def _get_sf(self, mask):
        """
        Get SF coefficients inside `mask`.
        """
        # when odfs are expressed in SH coefficients
        if self.B is not None:
            sf = self.odfs[mask[self.indices]].dot(self.B)
            # normalisation and scaling is done on SF coefficients;
            # all-zero ODFs stay at zero instead of becoming NaN
            if self.norm:
                sf = self._max_normalize(sf)
            return sf * self.scale
        # when odfs are in SF coefficients, the normalisation and scaling
        # are done during initialisation. We simply return them:
        return self.odfs[mask[self.indices]]

    def _get_all_vertices(self, offsets, sph_dirs, sf):
        """
        Get array of all the vertices of the ODFs to display.

        One copy of the sphere directions per offset, scaled either by the
        SF amplitudes (`radial_scale`) or uniformly by `scale`, then
        translated to its offset.
        """
        tiled_dirs = np.tile(sph_dirs, (len(offsets), 1))
        translations = np.repeat(offsets, len(sph_dirs), axis=0)
        if self.radial_scale:
            return tiled_dirs * sf.reshape(-1, 1) + translations
        return tiled_dirs * self.scale + translations

    def _get_all_faces(self, nb_odfs, nb_dirs):
        """
        Get array of all the faces of the ODFs to display.

        The face indices of the k-th ODF are shifted by k * nb_dirs.
        """
        return (np.tile(self.faces, (nb_odfs, 1))
                + np.repeat(np.arange(nb_odfs) * nb_dirs,
                            len(self.faces)).reshape(-1, 1))

    def _generate_color_for_vertices(self, sf):
        """
        Get array of all vertices colors.

        Raises IOError when `global_cm` is True but no colormap is set.
        """
        if self.global_cm:
            if self.colormap is None:
                raise IOError("if global_cm=True, colormap must be defined.")
            else:
                all_colors = create_colormap(sf.ravel(), self.colormap) * 255
        elif self.colormap is not None:
            if isinstance(self.colormap, str):
                # Map ODFs values [min, max] to [0, 1] for each ODF
                range_sf = sf.max(axis=-1) - sf.min(axis=-1)
                rescaled = sf - sf.min(axis=-1, keepdims=True)
                rescaled[range_sf > 0] /= range_sf[range_sf > 0][..., None]
                all_colors =\
                    create_colormap(rescaled.ravel(), self.colormap) * 255
            else:
                # a single color given as a 3-sequence, repeated everywhere
                all_colors = np.tile(
                    np.array(self.colormap).reshape(1, 3),
                    (sf.shape[0] * sf.shape[1], 1))
        else:
            # default: color each vertex by its absolute direction (RGB)
            all_colors = np.tile(np.abs(self.vertices) * 255, (len(sf), 1))
        return all_colors.astype(np.uint8)
Example #9
0
def ribbon(molecule):
    """Create an actor for ribbon molecular representation.

    Parameters
    ----------
    molecule : Molecule
        The molecule to be rendered.

    Returns
    -------
    molecule_actor : vtkActor
        Actor created to render the ribbon representation of the molecule to
        be visualized.

    References
    ----------
    Richardson, J.S. The anatomy and taxonomy of protein structure
    `Advances in Protein Chemistry, 1981, 34, 167-339.
    <https://doi.org/10.1016/S0065-3233(08)60520-3>`_
    """
    coords = get_all_atomic_positions(molecule)
    all_atomic_numbers = get_all_atomic_numbers(molecule)
    num_total_atoms = molecule.total_num_atoms

    def _in_segment(chain_id, resi, segment):
        # A sheet/helix record matches an atom on the same chain
        # (segment[0]) whose residue lies in [segment[1], segment[3]].
        return chain_id == segment[0] and segment[1] <= resi <= segment[3]

    # Per-atom secondary structure code: helix ('h') takes precedence over
    # sheet ('s'); anything else is coil ('c').
    secondary_structures = np.ones(num_total_atoms)
    for i in range(num_total_atoms):
        chain_id = molecule.chain[i]
        resi = molecule.residue_seq[i]
        if any(_in_segment(chain_id, resi, helix)
               for helix in molecule.helix):
            secondary_structures[i] = ord('h')
        elif any(_in_segment(chain_id, resi, sheet)
                 for sheet in molecule.sheet):
            secondary_structures[i] = ord('s')
        else:
            secondary_structures[i] = ord('c')

    output = PolyData()
    point_data = output.GetPointData()

    def _add_numeric_array(values, name, array_type):
        # Deep-copy `values` into a named VTK array on the point data.
        vtk_array = nps.numpy_to_vtk(num_array=values,
                                     deep=True,
                                     array_type=array_type)
        vtk_array.SetName(name)
        point_data.AddArray(vtk_array)

    # atomic numbers; vtkProteinRibbonFilter requires the name "atom_type"
    _add_numeric_array(all_atomic_numbers, "atom_type", VTK_ID_TYPE)

    # atom names; vtkProteinRibbonFilter requires the name "atom_types"
    atom_names = StringArray()
    atom_names.SetName("atom_types")
    atom_names.SetNumberOfTuples(num_total_atoms)
    for i in range(num_total_atoms):
        atom_names.SetValue(i, molecule.atom_names[i])
    point_data.AddArray(atom_names)

    _add_numeric_array(molecule.residue_seq, "residue", VTK_ID_TYPE)
    _add_numeric_array(molecule.chain, "chain", VTK_UNSIGNED_CHAR)
    _add_numeric_array(secondary_structures, "secondary_structures",
                       VTK_UNSIGNED_CHAR)
    _add_numeric_array(np.ones(num_total_atoms),
                       "secondary_structures_begin", VTK_UNSIGNED_CHAR)
    _add_numeric_array(np.ones(num_total_atoms),
                       "secondary_structures_end", VTK_UNSIGNED_CHAR)
    _add_numeric_array(molecule.is_hetatm, "ishetatm", VTK_UNSIGNED_CHAR)
    _add_numeric_array(molecule.model, "model", VTK_UNSIGNED_INT)

    table = PTable()

    # for colors and radii of hetero-atoms
    radii = np.ones((num_total_atoms, 3))
    rgb = np.ones((num_total_atoms, 3))

    for i in range(num_total_atoms):
        radii[i] = np.repeat(table.atomic_radius(all_atomic_numbers[i], 'VDW'),
                             3)
        rgb[i] = table.atom_color(all_atomic_numbers[i])

    Rgb = nps.numpy_to_vtk(num_array=rgb,
                           deep=True,
                           array_type=VTK_UNSIGNED_CHAR)
    Rgb.SetName("rgb_colors")
    point_data.SetScalars(Rgb)

    Radii = nps.numpy_to_vtk(num_array=radii, deep=True, array_type=VTK_FLOAT)
    Radii.SetName("radius")
    point_data.SetVectors(Radii)

    # setting the coordinates
    points = numpy_to_vtk_points(coords)
    output.SetPoints(points)

    ribbonFilter = ProteinRibbonFilter()
    ribbonFilter.SetInputData(output)
    ribbonFilter.SetCoilWidth(0.2)
    ribbonFilter.SetDrawSmallMoleculesAsSpheres(0)
    mapper = PolyDataMapper()
    mapper.SetInputConnection(ribbonFilter.GetOutputPort())
    molecule_actor = Actor()
    molecule_actor.SetMapper(mapper)
    return molecule_actor