class CellDerivatives(FilterBase):
    """Compute derivatives of the input point scalar and vector data.

    The resulting gradients are emitted as cell data.  One typical use is
    an approximate calculation of the vorticity.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.CellDerivatives, args=(), allow_none=False,
                      record=True)

    # Accepted inputs: any dataset, attribute type and attribute.
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    # Produced outputs: any dataset, attribute type and attribute.
    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    def has_output_port(self):
        """Report that this filter exposes an output port (always True)."""
        return True

    def get_output_object(self):
        """Return the ``output_port`` of the wrapped TVTK filter."""
        return self.filter.output_port
class CellToPointData(FilterBase):
    """Convert cell attribute data into point data.

    Each point receives the average of the data of the cells that use it.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter; cell data is passed through unchanged.
    filter = Instance(tvtk.CellDataToPointData, args=(),
                      kw={'pass_cell_data': 1}, allow_none=False,
                      record=True)

    # Consumes cell attributes on any dataset ...
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['cell'], attributes=['any'])

    # ... and may produce anything.
    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    def has_output_port(self):
        """Report that this filter exposes an output port (always True)."""
        return True

    def get_output_object(self):
        """Return the ``output_port`` of the wrapped TVTK filter."""
        return self.filter.output_port
class ExtractVectorComponents(FilterBase):
    """ This wraps the TVTK ExtractVectorComponents filter and allows
    one to select any of the three components of an input vector data
    attribute."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.ExtractVectorComponents, args=(), allow_none=False)

    # The Vector Component to be extracted
    component = Enum('x-component',
                     'y-component',
                     'z-component',
                     desc='component of the vector to be extracted')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['vectors'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    view = View(Group(Item(name='component')), resizable=True)

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_pipeline(self):
        """Connect the filter to the first input and publish the output
        for the currently selected component.

        Does nothing when there are no inputs.
        """
        inputs = self.inputs
        if len(inputs) == 0:
            return

        fil = self.filter
        # Use the same connection helper as the other filters in this file
        # rather than the VTK5-only ``fil.input = ...`` assignment, which
        # no longer exists from VTK6 onwards.
        self.configure_connection(fil, inputs[0])
        fil.update()
        self._component_changed(self.component)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _component_changed(self, value):
        """Publish the TVTK output matching the selected component.

        ``value`` is one of the ``component`` Enum choices.  Does nothing
        when there are no inputs.
        """
        if len(self.inputs) == 0:
            return

        # Enum choice -> name of the TVTK filter attribute holding that
        # component's output.
        output_names = {
            'x-component': 'vx_component',
            'y-component': 'vy_component',
            'z-component': 'vz_component',
        }
        output_name = output_names.get(value)
        if output_name is not None:
            self._set_outputs([getattr(self.filter, output_name)])
        self.render()
class CellToPointData(FilterBase):
    """Transforms cell attribute data to point data by averaging the
    cell data from the cells at the point.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.CellDataToPointData,
                      args=(),
                      kw={'pass_cell_data': 1},
                      allow_none=False,
                      record=True)

    # Information about what this object can consume/produce.
    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['cell'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    def update_pipeline(self):
        """Connect the filter to the first input and publish the typed
        output that matches the input dataset's type.

        Does nothing when there are no inputs.  If the input dataset
        matches none of the listed types, the outputs are left unchanged.
        """
        inputs = self.inputs
        if len(inputs) == 0:
            return

        fil = self.filter
        self.configure_connection(fil, inputs[0])
        fil.update()
        dataset = inputs[0].get_output_dataset()
        # This filter creates different outputs depending on the input.
        # The pairs are checked in order with ``is_a``, which also matches
        # subclasses, so vtkStructuredPoints (a vtkImageData subclass)
        # must come before vtkImageData.  A tuple keeps that order explicit
        # instead of depending on dict iteration order.
        out_map = (
            ('vtkStructuredGrid', 'structured_grid_output'),
            ('vtkRectilinearGrid', 'rectilinear_grid_output'),
            ('vtkStructuredPoints', 'structured_points_output'),
            ('vtkUnstructuredGrid', 'unstructured_grid_output'),
            ('vtkPolyData', 'poly_data_output'),
            ('vtkImageData', 'image_data_output'),
        )
        # Find the input data type and pass the matching output on.
        for vtk_type, output_name in out_map:
            if dataset.is_a(vtk_type):
                self._set_outputs([getattr(fil, output_name)])
                break

    # NOTE(review): this test method appears to have been pasted in from a
    # registry test case; as written it is a method of this class.
    def test_no_valid_reader(self):
        """Test that if there is no reader which can read the file with
        certainty, the registry returns the last one of the readers
        which don't have a can_read_test and claim to read the file with
        the given extension"""
        open_dummy = SourceMetadata(
            id="DummyFile",
            class_name="mayavi.tests.test_registry.DummyReader",
            menu_name="&PLOT3D file",
            tooltip="Open a PLOT3D data data",
            desc="Open a PLOT3D data data",
            help="Open a PLOT3D data data",
            extensions=['xyz'],
            wildcard='PLOT3D files (*.xyz)|*.xyz',
            can_read_test='mayavi.tests.test_registry:DummyReader.check_read',
            output_info=PipelineInfo(datasets=['structured_grid'],
                                     attribute_types=['any'],
                                     attributes=['any']))
        registry.sources.append(open_dummy)

        # Temporarily remove the poly data reader, remembering its position
        # so it can be restored (None if it is not registered).
        removed = None
        for index, src in enumerate(registry.sources[:]):
            if src.id == 'PolyDataFile':
                removed = (index, src)
                registry.sources.remove(src)
                break

        try:
            reader = registry.get_file_reader(get_example_data('tiny.xyz'))
            reader_callable = reader.get_callable()
            self.assertEqual(reader_callable.__name__, 'PLOT3DReader')
        finally:
            # Restore the global registry even if the assertion fails, so
            # later tests are not affected.
            if removed is not None:
                registry.sources.insert(removed[0], removed[1])
            registry.sources.remove(open_dummy)
class DecimatePro(FilterBase):
    """Lower the triangle count of a mesh with ``tvtk.DecimatePro``."""

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.DecimatePro,
                      args=(),
                      allow_none=False,
                      record=True)

    # Poly data in, poly data out.
    input_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #7
class VolumeReader(Source):
    """A Volume reader wrapping ``tvtk.Volume16Reader``.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    file_prefix = Str('', desc='File prefix for the volume files')

    # The VTK data file reader.
    reader = Instance(tvtk.Volume16Reader,
                      args=(),
                      allow_none=False,
                      record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    ########################################
    # View related code.
    # Our view.
    view = View(Group(Item(name='reader', style='custom', resizable=True),
                      show_labels=False),
                resizable=True)

    ######################################################################
    # `Source` interface
    ######################################################################

    def __init__(self, file_prefix='', configure=True, **traits):
        """Create the reader.

        Parameters
        ----------
        file_prefix : str
            Prefix of the volume files.  When non-empty it is applied to
            the TVTK reader before the optional configuration dialog.
        configure : bool
            If True, show the reader's traits editor modally.
        """
        super(VolumeReader, self).__init__(**traits)
        # Honour an explicitly supplied prefix; previously the argument
        # was accepted but silently discarded.
        if file_prefix:
            self.reader.file_prefix = file_prefix
        if configure:
            self.reader.edit_traits(kind='livemodal')

        self.file_prefix = self.reader.file_prefix

    def update(self):
        """Re-execute the reader and re-render (no-op without a prefix)."""
        if not self.file_prefix:
            return
        self.reader.update()
        self.render()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_prefix_changed(self, value):
        # An empty prefix leaves the reader untouched.
        if not value:
            return
        self.reader.file_prefix = value
        self._update_reader_output()

    def _update_reader_output(self):
        self.reader.update()
        self.reader.update_information()
        # Re-render whenever any reader trait changes.
        self.reader.on_trait_change(self.render)
        self.outputs = [self.reader.output]
        self.data_changed = True
# Example #8
class PointToCellData(CellToPointData):
    """Convert point attribute data into cell data via
    ``tvtk.PointDataToCellData``; point data is passed through.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.PointDataToCellData, args=(),
                      kw={'pass_point_data': 1}, allow_none=False,
                      record=True)

    # Consumes point attributes ...
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['point'], attributes=['any'])

    # ... and produces cell attributes.
    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['cell'], attributes=['any'])
class GaussianSplatter(FilterBase):
    """Wrap ``tvtk.GaussianSplatter``; accepts any dataset and produces
    image data."""

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.GaussianSplatter, args=(), allow_none=False,
                      record=True)

    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['image_data'], attribute_types=['any'], attributes=['any'])
class Contour(Wrapper):
    """
    Iso-surface filter for any input dataset, implemented by wrapping the
    Contour component.
    """
    # Persistence version number for this class.
    __version__ = 0

    # The wrapped contour component.
    filter = Instance(ContourComponent, args=(), record=True)

    # Needs point attributes on its input.
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['point'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])
class Tube(FilterBase):
    """Wrap lines in tubes using ``tvtk.TubeFilter``."""

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.TubeFilter,
                      args=(),
                      allow_none=False,
                      record=True)

    # Poly data in, poly data out.
    input_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #12
class CubeAxesActor2D(tvtk.CubeAxesActor2D):
    """ Just has a different view than the tvtk.CubeAxesActor2D, with an
        additional tick box.
    """

    # Automatically fit the bounds of the axes to the data.
    # NOTE(review): ``true`` must be a trait (e.g. a Bool(True) alias)
    # imported elsewhere in this module -- confirm.
    use_data_bounds = true

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    ########################################
    # The view of this object.

    traits_view = View(Group(
                        Group(
                            Item('visibility'),
                            HGroup(
                                 Item('x_axis_visibility', label='X axis'),
                                 Item('y_axis_visibility', label='Y axis'),
                                 Item('z_axis_visibility', label='Z axis'),
                                ),
                            show_border=True, label='Visibility'),
                        Group(
                            Item('use_ranges'),
                            HGroup(
                                 Item('ranges', enabled_when='use_ranges'),
                                ),
                            show_border=True),
                        Group(
                            Item('use_data_bounds'),
                            HGroup(
                                 # Manual bounds only apply when they are
                                 # not taken from the data.
                                 Item('bounds',
                                    enabled_when='not use_data_bounds'),
                                ),
                            show_border=True),
                        Group(
                            Item('x_label'),
                            Item('y_label'),
                            Item('z_label'),
                            Item('label_format'),
                            Item('number_of_labels'),
                            Item('font_factor'),
                            show_border=True),
                        HGroup(Item('show_actual_bounds',
                                label='Use size bigger than screen',
                                editor=BooleanEditor())),
                        Item('fly_mode'),
                        Item('corner_offset'),
                        Item('layer_number'),
                       springy=True,
                      ),
                     scrollable=True,
                     resizable=True,
                     )
class Stripper(FilterBase):
    """ Build triangle strips and/or poly-lines with ``tvtk.Stripper``.
        This helps regularize fragmented surfaces, such as those
        produced by the Tube filter.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.Stripper,
                      args=(),
                      allow_none=False,
                      record=True)

    # Poly data in, poly data out.
    input_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
class ElevationFilter(FilterBase):
    """ Create scalar data from the elevation along a chosen direction """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.ElevationFilter, args=(), allow_none=False,
                      record=True)

    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])
class ExtractUnstructuredGrid(FilterBase):
    """Let the user select a portion of an unstructured grid.
    """
    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.ExtractUnstructuredGrid, args=(),
                      allow_none=False, record=True)

    # Unstructured grid in, unstructured grid out.
    input_info = PipelineInfo(datasets=['unstructured_grid'],
                              attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(datasets=['unstructured_grid'],
                               attribute_types=['any'], attributes=['any'])
# Example #16
class GreedyTerrainDecimation(FilterBase):
    """ Performs a triangulation of image data after simplifying it. """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.GreedyTerrainDecimation,
                      args=(),
                      allow_none=False,
                      record=True)

    # The wrapped VTK filter only accepts image data (a height field), as
    # the docstring says; advertising 'any' let the filter be attached to
    # inputs it cannot process.
    input_info = PipelineInfo(datasets=['image_data'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['poly_data'],
                               attribute_types=['any'],
                               attributes=['any'])
# Example #17
class Delaunay2D(FilterBase):
    """Run a 2D Delaunay triangulation via ``tvtk.Delaunay2D``.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.Delaunay2D,
                      args=(),
                      allow_none=False,
                      record=True)

    # Accepted point-based inputs.
    input_info = PipelineInfo(datasets=['structured_grid',
                                        'poly_data',
                                        'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #18
class WarpVector(PolyDataNormals):
    """Displace the input data along its point vector attribute, scaled
    by a scale factor.  Handy for visualizing flow profiles or
    displacements.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.WarpVector,
                      args=(),
                      allow_none=False,
                      record=True)

    # Requires a vectors attribute on the input.
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['vectors'])

    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])
# Example #19
class QuadricDecimation(PolyDataFilterBase):
    """ Simplify the triangles of a mesh with ``tvtk.QuadricDecimation`` """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.QuadricDecimation, args=(), allow_none=False,
                      record=True)

    # Poly data in, poly data out.
    input_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #20
class MaskPoints(FilterBase):
    """Pass only a selection of the input points downstream, e.g. to
    subsample them.  Geometry is not passed on, so all grid
    information is lost.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.MaskPoints,
                      args=(),
                      allow_none=False,
                      record=True)

    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])

    # The output is always poly data.
    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #21
class TriangleFilter(FilterBase):

    """ Converts input polygons and triangle strips to triangles using
    the tvtk.TriangleFilter class.  This is useful when you have a
    downstream filter that only processes triangles."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.TriangleFilter, args=(), allow_none=False, record=True)

    # NOTE(review): the wrapped VTK filter only accepts poly data; 'any'
    # is kept here for compatibility -- confirm whether it should be
    # narrowed.
    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    # vtkTriangleFilter always produces poly data, so advertising an
    # unstructured_grid output was incorrect.
    output_info = PipelineInfo(datasets=['poly_data'],
                               attribute_types=['any'],
                               attributes=['any'])
# Example #22
class WarpScalar(PolyDataNormals):

    """Displace the input data along a direction (either the normals or
    an explicitly given direction), scaled by the local scalar value.
    Handy for producing carpet plots.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.WarpScalar,
                      args=(),
                      allow_none=False,
                      record=True)

    # Requires a scalars attribute on the input.
    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['scalars'])

    output_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])
# Example #23
class MaskPoints(FilterBase):

    """Selectively passes the input points downstream.  This can be
    used to subsample the input points.  Note that this does not pass
    geometry data, this means all grid information is lost.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.MaskPoints, args=(), allow_none=False, record=True)

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['poly_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_pipeline(self):
        """Cap ``maximum_number_of_points`` at the input size, then defer
        to the base class."""
        # FIXME: This is needed, for with VTK-5.10 (for sure), the filter
        # allocates memory for maximum_number_of_points which is impossibly
        # large,  so we set it to the number of points in the input
        # for safety.
        # The cap is only applied when an input is available; previously
        # an empty ``inputs`` list raised IndexError here.
        n_points = self._find_number_of_points_in_input()
        if n_points is not None:
            self.filter.maximum_number_of_points = n_points
        super(MaskPoints, self).update_pipeline()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _find_number_of_points_in_input(self):
        """Return the point count of the first output of the first input.

        The output is updated first when it has an ``update`` method.
        Returns None when there is no input or the input has no outputs.
        """
        inputs = self.inputs
        if len(inputs) == 0 or len(inputs[0].outputs) == 0:
            return None
        inp = inputs[0].outputs[0]
        if hasattr(inp, 'update'):
            inp.update()
        return inp.number_of_points
# Example #24
class ChacoReader(Source):
    """A Chaco reader wrapping ``tvtk.ChacoReader``.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    base_name = Str('', desc='basename of the Chaco files')

    # The VTK data file reader.
    reader = Instance(tvtk.ChacoReader, args=(), allow_none=False, record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['unstructured_grid'])

    ########################################
    # View related code.
    # Our view.
    view = View(Group(Item(name='reader', style='custom', resizable=True),
                      show_labels=False),
                resizable=True)

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def __init__(self, base_name='', configure=True, **traits):
        """Create the reader.

        Parameters
        ----------
        base_name : str
            Base name of the Chaco files.  When non-empty it is applied to
            the TVTK reader before the optional configuration dialog.
        configure : bool
            If True, show the reader's traits editor modally.
        """
        super(ChacoReader, self).__init__(**traits)
        # Honour an explicitly supplied base name; previously the argument
        # was accepted but silently discarded.
        if base_name:
            self.reader.base_name = base_name
        if configure:
            self.reader.edit_traits(kind='livemodal')
        self.base_name = self.reader.base_name

    def update(self):
        """Re-execute the reader and re-render (no-op without a base name)."""
        if not self.base_name:
            return
        self.reader.update()
        self.render()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _base_name_changed(self, value):
        # An empty base name leaves the reader untouched.
        if not value:
            return
        self.reader.base_name = value
        self._update_reader_output()

    def _update_reader_output(self):
        self.reader.update()
        self.reader.update_information()
        # Re-render whenever any reader trait changes.
        self.reader.on_trait_change(self.render)
        self.outputs = [self.reader.output]
        self.data_changed = True
class PolyDataNormals(PolyDataFilterBase):
    """Compute normals for the input data so that meshes look smoother.
    Intended to work with any input dataset.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.PolyDataNormals, args=(), allow_none=False,
                      record=True)

    # Poly data in, poly data out.
    input_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])
# Example #26
class ImageChangeInformation(FilterBase):
    """
    Filter for modifying the spacing and origin of an input ImageData
    dataset, via ``tvtk.ImageChangeInformation``.
    """

    # Persistence version number for this class.
    __version__ = 0

    # The wrapped TVTK filter instance.
    filter = Instance(tvtk.ImageChangeInformation, args=(),
                      allow_none=False, record=True)

    # Image data in, image data out.
    input_info = PipelineInfo(
        datasets=['image_data'], attribute_types=['any'], attributes=['any'])

    output_info = PipelineInfo(
        datasets=['image_data'], attribute_types=['any'], attributes=['any'])
class CutPlane(Collection):
    """
    A cut plane for slicing through any dataset, with an interactive 3D
    widget to position and move the slice.
    """
    # Persistence version number for this class.
    __version__ = 0

    input_info = PipelineInfo(
        datasets=['any'], attribute_types=['any'], attributes=['any'])
    output_info = PipelineInfo(
        datasets=['poly_data'], attribute_types=['any'], attributes=['any'])

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def setup_pipeline(self):
        """Build an implicit-plane widget and a cutter driven by its plane."""
        plane_widget = ImplicitPlane()
        cutter = Cutter(cut_function=plane_widget.plane)
        self.filters = [plane_widget, cutter]
# Example #28
class PointLoad(Source):
    """Source wrapping ``tvtk.PointLoad``; produces image data."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    point_load = Instance(tvtk.PointLoad,
                          args=(),
                          allow_none=False,
                          record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Create the UI for the traits.
    view = View(Group(Item(name='point_load', style='custom', resizable=True),
                      label='PointLoad',
                      show_labels=False),
                resizable=True)

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Call parent class' init.
        super(PointLoad, self).__init__(**traits)

        # Call render every time source traits change.
        self.point_load.on_trait_change(self.render)

        # Setup the outputs.
        self.outputs = [self.point_load.output]

    def has_output_port(self):
        """ Return True as the point load has output port."""
        return True

    def get_output_object(self):
        """ Return the point load output port."""
        return self.point_load.output_port

    # NOTE(review): this test method appears to have been pasted in from a
    # registry test case; as written it is a method of this class.
    def test_multiple_valid_readers(self):
        """Test if the fixture works fine if there are multiple readers
        capable of reading the file properly"""
        # Inserting a dummy reader into the registry also capable of
        # reading files with extension 'xyz'
        open_dummy = SourceMetadata(
                id            = "DummyFile",
                class_name    = "mayavi.tests.test_registry.DummyReader",
                menu_name     = "&PLOT3D file",
                tooltip       = "Open a PLOT3D data data",
                desc        = "Open a PLOT3D data data",
                help        = "Open a PLOT3D data data",
                extensions = ['xyz'],
                wildcard = 'PLOT3D files (*.xyz)|*.xyz',
                can_read_test = 'mayavi.tests.test_registry:DummyReader.check_read',
                output_info = PipelineInfo(datasets=['structured_grid'],
                    attribute_types=['any'],
                    attributes=['any'])
                )
        registry.sources.append(open_dummy)
        # (original index, source) of every reader temporarily removed.
        removed = []
        try:
            reader = registry.get_file_reader(get_example_data('tiny.xyz'))
            reader_callable = reader.get_callable()
            self.assertEqual(reader_callable.__name__, 'PLOT3DReader')

            # Removing existing readers for .xyz extensions to check if the
            # Dummy reader now reads it.  Indices come from a snapshot, so
            # re-inserting them in ascending order restores the original
            # positions.
            for index, src in enumerate(registry.sources[:]):
                if 'xyz' in src.extensions and src.id != 'DummyFile':
                    removed.append((index, src))
                    registry.sources.remove(src)

            reader = registry.get_file_reader(get_example_data('tiny.xyz'))
            reader_callable = reader.get_callable()
            self.assertEqual(reader_callable.__name__, 'DummyReader')
        finally:
            # Restore the global registry even if an assertion fails, so
            # later tests are not affected.
            for index, src in removed:
                registry.sources.insert(index, src)
            registry.sources.remove(open_dummy)
class AddModuleManager(ModuleAction):
    """ Action that inserts a new ModuleManager into the tree. """

    tooltip = "Add a ModuleManager to the current source/filter"

    description = "Add a ModuleManager to the current source/filter"

    # Registry metadata describing this action.
    metadata = ModuleMetadata(
        id="AddModuleManager",
        class_name="mayavi.core.module_manager.ModuleManager",
        menu_name="&Add ModuleManager",
        tooltip="Add a ModuleManager to the current source/filter",
        description="Add a ModuleManager to the current source/filter",
        input_info=PipelineInfo(datasets=['any'],
                                attribute_types=['any'],
                                attributes=['any']))

    def perform(self, event):
        """ Create a ModuleManager, add it via ``self.mayavi`` and make it
        the engine's current selection. """
        from mayavi.core.module_manager import ModuleManager
        manager = ModuleManager()
        app = self.mayavi
        app.add_module(manager)
        app.engine.current_selection = manager