Example #1
0
class Volume3D(Volume):
    """Variant of `Volume` whose mapper type defaults to
    'VolumeTextureMapper3D'.
    """

    # Same dynamic enumeration as declared here originally, still fed by
    # `_mapper_types`; only the initial value is pinned.
    volume_mapper_type = DEnum(
        value='VolumeTextureMapper3D',
        values_name='_mapper_types',
        desc='volume mapper to use',
    )

    def _update_ctf_fired(self):
        """Handler for the `update_ctf` event: only requests a re-render."""
        self.render()
Example #2
0
class Volume(Module):
    """The Volume module visualizes scalar fields using volumetric
    visualization techniques.  This supports ImageData and
    UnstructuredGrid data.  It also supports the FixedPointRenderer
    for ImageData.  However, the performance is slow so your best bet
    is probably with the ImageData based renderers.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Name of the volume mapper to use.  The valid choices are whatever
    # `_mapper_types` currently holds.
    volume_mapper_type = DEnum(values_name='_mapper_types',
                               desc='volume mapper to use')

    # Name of the ray cast function to use.  The valid choices are
    # whatever `_ray_cast_functions` currently holds.
    ray_cast_function_type = DEnum(values_name='_ray_cast_functions',
                                   desc='Ray cast function to use')

    # The tvtk.Volume actor; assigned once in `setup_pipeline`.
    volume = ReadOnly

    # Read-only public views of the private `_volume_mapper`,
    # `_volume_property` and `_ray_cast_function` traits (see the
    # `_get_*` methods below).
    volume_mapper = Property(record=True)

    volume_property = Property(record=True)

    ray_cast_function = Property(record=True)

    # The LUT manager shown in the 'Legend' tab; it is started/stopped
    # together with this module.
    lut_manager = Instance(VolumeLUTManager,
                           args=(),
                           allow_none=False,
                           record=True)

    input_info = PipelineInfo(datasets=['image_data', 'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['scalars'])

    ########################################
    # View related code.

    # Pressing this button pushes the CTF into the LUT and re-renders
    # (see `_update_ctf_fired`).
    update_ctf = Button('Update CTF')

    view = View(Group(Item(name='_volume_property',
                           style='custom',
                           editor=VolumePropertyEditor,
                           resizable=True),
                      Item(name='update_ctf'),
                      label='CTF',
                      show_labels=False),
                Group(
                    Item(name='volume_mapper_type'),
                    Group(Item(name='_volume_mapper',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    Item(name='ray_cast_function_type'),
                    Group(Item(name='_ray_cast_function',
                               enabled_when='len(_ray_cast_functions) > 0',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    label='Mapper',
                ),
                Group(Item(name='_volume_property',
                           style='custom',
                           resizable=True),
                      label='Property',
                      show_labels=False),
                Group(Item(name='volume',
                           style='custom',
                           editor=InstanceEditor(),
                           resizable=True),
                      label='Volume',
                      show_labels=False),
                Group(Item(name='lut_manager', style='custom', resizable=True),
                      label='Legend',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits
    _volume_mapper = Instance(tvtk.AbstractVolumeMapper)
    _volume_property = Instance(tvtk.VolumeProperty)
    _ray_cast_function = Instance(tvtk.Object)

    # The mapper names offered to the user for the current input.
    _mapper_types = List(Str, [
        'TextureMapper2D',
        'RayCastMapper',
    ])

    # The mapper names reported as available; filled in
    # `setup_pipeline`.
    _available_mapper_types = List(Str)

    # The ray cast function names valid for the current mapper.
    _ray_cast_functions = List(Str)

    # The scalar range of the input; changing it rescales the CTFs.
    current_range = Tuple

    # The color transfer function.
    _ctf = Instance(ColorTransferFunction)
    # The opacity values.
    _otf = Instance(PiecewiseFunction)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the state to persist.

        The transfer functions are stored under 'ctf_state' (via
        `save_ctfs`) and the `current_range`, `_ctf` and `_otf` entries
        are dropped from the state.
        """
        d = super(Volume, self).__get_pure_state__()
        d['ctf_state'] = save_ctfs(self._volume_property)
        for name in ('current_range', '_ctf', '_otf'):
            d.pop(name, None)
        return d

    def __set_pure_state__(self, state):
        """Restore the module from `state`.

        The mapper type is set first, then the remaining state (except
        'ctf_state'), and finally the transfer functions are reloaded
        and pushed to the LUT.
        """
        self.volume_mapper_type = state['_volume_mapper_type']
        state_pickler.set_state(self, state, ignore=['ctf_state'])
        ctf_state = state['ctf_state']
        ctf, otf = load_ctfs(ctf_state, self._volume_property)
        self._ctf = ctf
        self._otf = otf
        self._update_ctf_fired()

    ######################################################################
    # `Module` interface
    ######################################################################
    def start(self):
        """Start the module and then its LUT manager."""
        super(Volume, self).start()
        self.lut_manager.start()

    def stop(self):
        """Stop the module and then its LUT manager."""
        super(Volume, self).stop()
        self.lut_manager.stop()

    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.
        """
        v = self.volume = tvtk.Volume()
        vp = self._volume_property = tvtk.VolumeProperty()

        # Start with default transfer functions over the range 0-255.
        self._ctf = ctf = default_CTF(0, 255)
        self._otf = otf = default_OTF(0, 255)
        vp.set_color(ctf)
        vp.set_scalar_opacity(otf)
        vp.shade = True
        vp.interpolation_type = 'linear'
        v.property = vp

        # Re-render whenever the volume or its property changes.
        v.on_trait_change(self.render)
        vp.on_trait_change(self.render)

        available_mappers = find_volume_mappers()
        if is_volume_pro_available():
            self._mapper_types.append('VolumeProMapper')
            available_mappers.append('VolumeProMapper')

        self._available_mapper_types = available_mappers
        if 'FixedPointVolumeRayCastMapper' in available_mappers:
            self._mapper_types.append('FixedPointVolumeRayCastMapper')

        self.actors.append(v)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        dataset = mm.source.get_output_dataset()

        # Report an error and bail out early when the input dataset
        # cannot be volume rendered.
        ug = hasattr(tvtk, 'UnstructuredGridVolumeMapper')
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'Multiblock volume rendering')
                return
        elif dataset.is_a('vtkUniformGridAMR'):
            if 'AMRVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'AMR volume rendering')
                return
        elif ug:
            if not dataset.is_a('vtkImageData') \
                   and not dataset.is_a('vtkUnstructuredGrid'):
                error('Volume rendering only works with '
                      'StructuredPoints/ImageData/UnstructuredGrid datasets')
                return
        elif not dataset.is_a('vtkImageData'):
            error('Volume rendering only works with '
                  'StructuredPoints/ImageData datasets')
            return

        self._setup_mapper_types()
        self._setup_current_range()
        self._volume_mapper_type_changed(self.volume_mapper_type)
        self._update_ctf_fired()
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self._setup_mapper_types()
        self._setup_current_range()
        self._update_ctf_fired()
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _get_image_data_volume_mappers(self):
        """Return the available mappers among the smart/GPU ray cast
        mappers, in order of preference.
        """
        check = ('SmartVolumeMapper', 'GPUVolumeRayCastMapper',
                 'OpenGLGPUVolumeRayCastMapper')
        return [x for x in check if x in self._available_mapper_types]

    def _setup_mapper_types(self):
        """Sets up the mapper based on input data types.

        `_mapper_types` is replaced with the mappers that are both
        suitable for the input dataset and available.  For unstructured
        grids with no usable mapper the list is set to `['']`.
        """
        available = self._available_mapper_types
        dataset = self.module_manager.source.get_output_dataset()
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' in available:
                self._mapper_types = ['MultiBlockVolumeMapper']
        elif dataset.is_a('vtkUniformGridAMR'):
            # Only offer the AMR mapper when it is actually available
            # (this test used to be inverted).
            if 'AMRVolumeMapper' in available:
                self._mapper_types = ['AMRVolumeMapper']
        elif dataset.is_a('vtkUnstructuredGrid'):
            if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
                check = [
                    'UnstructuredGridVolumeZSweepMapper',
                    'UnstructuredGridVolumeRayCastMapper',
                ]
                mapper_types = [x for x in check if x in available]
                if len(mapper_types) == 0:
                    mapper_types = ['']
                self._mapper_types = mapper_types
        else:
            mapper_types = self._get_image_data_volume_mappers()
            if dataset.point_data.scalars.data_type not in \
               [vtkConstants.VTK_UNSIGNED_CHAR,
                vtkConstants.VTK_UNSIGNED_SHORT]:
                if 'FixedPointVolumeRayCastMapper' in available:
                    mapper_types.append('FixedPointVolumeRayCastMapper')
                elif len(mapper_types) == 0:
                    error('Available volume mappers only work with '
                          'unsigned_char or unsigned_short datatypes')
            else:
                check = [
                    'FixedPointVolumeRayCastMapper', 'VolumeProMapper',
                    'TextureMapper2D', 'RayCastMapper', 'TextureMapper3D'
                ]
                for mapper in check:
                    if mapper in available:
                        mapper_types.append(mapper)
            self._mapper_types = mapper_types

    def _setup_current_range(self):
        """Sync the LUT manager defaults with the scalar LUT manager and
        update `current_range` from the input's point scalars.
        """
        mm = self.module_manager
        # Set the default name and range for our lut.
        lm = self.lut_manager
        slm = mm.scalar_lut_manager
        lm.trait_set(default_data_name=slm.default_data_name,
                     default_data_range=slm.default_data_range)

        # Set the current range.
        dataset = mm.source.get_output_dataset()
        dsh = DataSetHelper(dataset)
        name, rng = dsh.get_range('scalars', 'point')
        if name is None:
            error('No scalars in input data!')
            rng = (0, 255)

        if self.current_range != rng:
            self.current_range = rng

    def _get_volume_mapper(self):
        return self._volume_mapper

    def _get_volume_property(self):
        return self._volume_property

    def _get_ray_cast_function(self):
        return self._ray_cast_function

    def _volume_mapper_type_changed(self, value):
        """Create the mapper named by `value` and wire it into the
        pipeline.

        Does nothing when there is no module manager yet, or when
        `value` is not a known mapper type (e.g. the `''` placeholder
        that `_setup_mapper_types` uses when nothing is available); in
        the latter case the current mapper is left untouched.
        """
        mm = self.module_manager
        if mm is None:
            return

        # Mapper type name -> name of the tvtk class implementing it.
        # The class is looked up lazily since not every VTK build
        # provides every mapper.
        tvtk_class_names = {
            'RayCastMapper': 'VolumeRayCastMapper',
            'MultiBlockVolumeMapper': 'MultiBlockVolumeMapper',
            'AMRVolumeMapper': 'AMRVolumeMapper',
            'SmartVolumeMapper': 'SmartVolumeMapper',
            'GPUVolumeRayCastMapper': 'GPUVolumeRayCastMapper',
            'OpenGLGPUVolumeRayCastMapper': 'OpenGLGPUVolumeRayCastMapper',
            'TextureMapper2D': 'VolumeTextureMapper2D',
            'TextureMapper3D': 'VolumeTextureMapper3D',
            'VolumeProMapper': 'VolumeProMapper',
            'FixedPointVolumeRayCastMapper': 'FixedPointVolumeRayCastMapper',
            'UnstructuredGridVolumeRayCastMapper':
                'UnstructuredGridVolumeRayCastMapper',
            'UnstructuredGridVolumeZSweepMapper':
                'UnstructuredGridVolumeZSweepMapper',
        }
        class_name = tvtk_class_names.get(value)
        if class_name is None:
            # Unknown/empty type: previously this fell through with
            # `new_vm` unbound and raised after detaching the old
            # mapper's listener.
            return

        old_vm = self._volume_mapper
        if old_vm is not None:
            old_vm.on_trait_change(self.render, remove=True)

        new_vm = getattr(tvtk, class_name)()
        self._volume_mapper = new_vm
        if value == 'RayCastMapper':
            # Only the ray cast mapper has selectable ray cast functions.
            self._ray_cast_functions = [
                'RayCastCompositeFunction', 'RayCastMIPFunction',
                'RayCastIsosurfaceFunction'
            ]
            new_vm.volume_ray_cast_function = \
                tvtk.VolumeRayCastCompositeFunction()
        else:
            self._ray_cast_functions = ['']

        src = mm.source
        self.configure_connection(new_vm, src)
        self.volume.mapper = new_vm
        new_vm.on_trait_change(self.render)

    def _update_ctf_fired(self):
        """Push the volume property's transfer functions into the LUT
        and re-render.
        """
        set_lut(self.lut_manager.lut, self._volume_property)
        self.render()

    def _current_range_changed(self, old, new):
        """Rescale the transfer functions to the new range and
        re-render.
        """
        rescale_ctfs(self._volume_property, new)
        self.render()

    def _ray_cast_function_type_changed(self, old, new):
        """Swap the mapper's ray cast function for the one named `new`.

        An empty `new` clears `_ray_cast_function`.
        """
        rcf = self.ray_cast_function
        # `_ray_cast_function` may still be None even though `old` is
        # non-empty: the default function installed for 'RayCastMapper'
        # is set on the mapper only, not stored here.
        if len(old) > 0 and rcf is not None:
            rcf.on_trait_change(self.render, remove=True)

        if len(new) > 0:
            new_rcf = getattr(tvtk, 'Volume%s' % new)()
            new_rcf.on_trait_change(self.render)
            self._volume_mapper.volume_ray_cast_function = new_rcf
            self._ray_cast_function = new_rcf
        else:
            self._ray_cast_function = None

        self.render()

    def _scene_changed(self, old, new):
        """Propagate the new scene to the LUT manager as well."""
        super(Volume, self)._scene_changed(old, new)
        self.lut_manager.scene = new
Example #3
0
class SetActiveAttribute(Filter):
    """
    This filter lets a user set the active data attribute (scalars,
    vectors and tensors) on a VTK dataset.  This is particularly useful
    if you need to do something like compute contours of one scalar on
    the contour of another scalar.
    """

    # Note: most of this code is from the XMLFileDataReader.

    # The version of this class.  Used for persistence.
    __version__ = 0

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    ########################################
    # Dynamic traits: These traits are dynamic and are automatically
    # updated depending on the contents of the file.

    # The active point scalar name.  An empty string indicates that
    # the attribute is "deactivated".  This is useful when you have
    # both point and cell attributes and want to use cell data by
    # default.
    point_scalars_name = DEnum(values_name='_point_scalars_list',
                               desc='scalar point data attribute to use')
    # The active point vector name.
    point_vectors_name = DEnum(values_name='_point_vectors_list',
                               desc='vectors point data attribute to use')
    # The active point tensor name.
    point_tensors_name = DEnum(values_name='_point_tensors_list',
                               desc='tensor point data attribute to use')

    # The active cell scalar name.
    cell_scalars_name = DEnum(values_name='_cell_scalars_list',
                              desc='scalar cell data attribute to use')
    # The active cell vector name.
    cell_vectors_name = DEnum(values_name='_cell_vectors_list',
                              desc='vectors cell data attribute to use')
    # The active cell tensor name.
    cell_tensors_name = DEnum(values_name='_cell_tensors_list',
                              desc='tensor cell data attribute to use')
    ########################################

    # Our view.
    view = View(
        Group(
            Item(name='point_scalars_name'),
            Item(name='point_vectors_name'),
            Item(name='point_tensors_name'),
            Item(name='cell_scalars_name'),
            Item(name='cell_vectors_name'),
            Item(name='cell_tensors_name'),
        ))

    ########################################
    # Private traits.

    # These private traits store the list of available data
    # attributes.  The non-private traits use these lists internally.
    _point_scalars_list = List(Str)
    _point_vectors_list = List(Str)
    _point_tensors_list = List(Str)
    _cell_scalars_list = List(Str)
    _cell_vectors_list = List(Str)
    _cell_tensors_list = List(Str)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute,
                                 args=(),
                                 allow_none=False)

    # Toggles if this is the first time this object has been used.
    _first = Bool(True)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the state to persist.

        The assign-attribute filter, the `_first` flag and the private
        attribute lists are dropped; the six public `*_name` values are
        stored explicitly.
        """
        d = super(SetActiveAttribute, self).__get_pure_state__()
        for name in ('_assign_attribute', '_first'):
            d.pop(name, None)
        # Pickle the 'point_scalars_name' etc. since these are
        # properties and not in __dict__.
        attr = {}
        for name in ('point_scalars', 'point_vectors', 'point_tensors',
                     'cell_scalars', 'cell_vectors', 'cell_tensors'):
            d.pop('_' + name + '_list', None)
            d.pop('_' + name + '_name', None)
            x = name + '_name'
            attr[x] = getattr(self, x)
        d.update(attr)

        return d

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_data(self):
        """Simply forward the `data_changed` event downstream."""
        self.data_changed = True

    def update_pipeline(self):
        """Connect the assign-attribute filter to the first output of
        the first input, refresh the attribute lists and publish the
        filter's output.  Does nothing without an input or output.
        """
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        aa = self._assign_attribute
        aa.input = self.inputs[0].outputs[0]
        self._update()
        self._set_outputs([aa.output])

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _update(self):
        """Updates the traits for the fields that are available in the
        input data.
        """
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        dataset = self.inputs[0].outputs[0]
        if self._first:
            # Force all attributes to be defined and computed
            dataset.update()
        pnt_attr, cell_attr = get_all_attributes(dataset)

        self._setup_data_traits(cell_attr, 'cell')
        self._setup_data_traits(pnt_attr, 'point')
        if self._first:
            self._first = False

    def _setup_data_traits(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.

        Note that an empty string (the "deactivated" choice) is
        appended in place to each list in `attributes`.
        """
        attrs = ['scalars', 'vectors', 'tensors']
        aa = self._assign_attribute
        dataset = self.inputs[0].outputs[0]
        data = getattr(dataset, '%s_data' % d_type)
        for attr in attrs:
            values = attributes[attr]
            values.append('')
            setattr(self, '_%s_%s_list' % (d_type, attr), values)
            if len(values) > 1:
                default = getattr(self, '%s_%s_name' % (d_type, attr))
                # On first use, fall back to the first available array
                # when nothing has been chosen yet.
                if self._first and len(default) == 0:
                    default = values[0]
                getattr(data, 'set_active_%s' % attr)(default)
                aa.assign(default, attr.upper(), d_type.upper() + '_DATA')
                aa.update()
                kw = {
                    '%s_%s_name' % (d_type, attr): default,
                    'trait_change_notify': False
                }
                # `trait_set` rather than the legacy `set` alias; the
                # change handlers must not fire here.
                self.trait_set(**kw)

    def _set_data_name(self, data_type, attr_type, value):
        """Make `value` the active `data_type` ('scalars', 'vectors' or
        'tensors') array of the `attr_type` ('point' or 'cell') data.

        An empty `value` deactivates the attribute.  Nothing happens
        when `value` is None or there is no input/output to act on.
        """
        # Same guard as `update_pipeline`/`_update`: an input without
        # outputs would otherwise raise an IndexError below.
        if value is None or len(self.inputs) == 0 \
           or len(self.inputs[0].outputs) == 0:
            return

        dataset = self.inputs[0].outputs[0]
        data = getattr(dataset, attr_type + '_data')
        method = getattr(data, 'set_active_%s' % data_type)
        if len(value) == 0:
            # If the value is empty then we deactivate that attribute.
            method(None)
            self.data_changed = True
            return

        aa = self._assign_attribute
        method(value)
        aa.assign(value, data_type.upper(), attr_type.upper() + '_DATA')
        aa.update()
        # Fire an event, so the changes propagate.
        self.data_changed = True

    def _point_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'point', value)

    def _point_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'point', value)

    def _point_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'point', value)

    def _cell_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'cell', value)

    def _cell_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'cell', value)

    def _cell_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'cell', value)
class VTKDataSource(Source):
    """This source manages a VTK dataset given to it.  When this
    source is pickled or persisted, it saves the data given to it in
    the form of a gzipped string.

    Note that if the VTK dataset has changed internally and you need
    to notify the mayavi pipeline to flush the data just call the
    `modified` method of the VTK dataset and the mayavi pipeline will
    update automatically.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The VTK dataset to manage.
    data = Instance(tvtk.DataSet, allow_none=False)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    ########################################
    # Dynamic traits: These traits are dynamic and are updated on the
    # _update_data method.

    # The active point scalar name.
    point_scalars_name = DEnum(values_name='_point_scalars_list',
                               desc='scalar point data attribute to use')
    # The active point vector name.
    point_vectors_name = DEnum(values_name='_point_vectors_list',
                               desc='vectors point data attribute to use')
    # The active point tensor name.
    point_tensors_name = DEnum(values_name='_point_tensors_list',
                               desc='tensor point data attribute to use')

    # The active cell scalar name.
    cell_scalars_name = DEnum(values_name='_cell_scalars_list',
                              desc='scalar cell data attribute to use')
    # The active cell vector name.
    cell_vectors_name = DEnum(values_name='_cell_vectors_list',
                              desc='vectors cell data attribute to use')
    # The active cell tensor name.
    cell_tensors_name = DEnum(values_name='_cell_tensors_list',
                              desc='tensor cell data attribute to use')

    ########################################
    # Our view.

    view = View(
        Group(
            Item(name='point_scalars_name'),
            Item(name='point_vectors_name'),
            Item(name='point_tensors_name'),
            Item(name='cell_scalars_name'),
            Item(name='cell_vectors_name'),
            Item(name='cell_tensors_name'),
            Item(name='data'),
        ))

    ########################################
    # Private traits.

    # These private traits store the list of available data
    # attributes.  The non-private traits use these lists internally.
    _point_scalars_list = List(Str)
    _point_vectors_list = List(Str)
    _point_tensors_list = List(Str)
    _cell_scalars_list = List(Str)
    _cell_vectors_list = List(Str)
    _cell_tensors_list = List(Str)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute,
                                 args=(),
                                 allow_none=False)

    # Toggles if this is the first time this object has been used.
    _first = Bool(True)

    # The ID of the observer for the data.
    _observer_id = Int(-1)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the state to persist.

        Runtime-only entries (the assign-attribute filter, the `_first`
        flag, the observer id and the private attribute lists) are
        dropped, and the dataset is stored under 'data' as a gzipped
        string.
        """
        d = super(VTKDataSource, self).__get_pure_state__()
        # '_observer_id' is the trait actually defined above; it is only
        # meaningful for the live dataset, so it must not be saved and
        # later restored over the id obtained in `_data_changed`.
        # '_observer' is kept in the list for backward compatibility.
        for name in ('_assign_attribute', '_first', '_observer',
                     '_observer_id'):
            d.pop(name, None)
        for name in ('point_scalars', 'point_vectors', 'point_tensors',
                     'cell_scalars', 'cell_vectors', 'cell_tensors'):
            d.pop('_' + name + '_list', None)
            d.pop('_' + name + '_name', None)
        data = self.data
        if data is not None:
            sdata = write_dataset_to_string(data)
            z = gzip_string(sdata)
            d['data'] = z
        return d

    def __set_pure_state__(self, state):
        """Restore the source from `state`: first the dataset (if any),
        then our own traits, and finally the children.
        """
        z = state.data
        if z is not None:
            d = gunzip_string(z)
            r = tvtk.DataSetReader(read_from_input_string=1, input_string=d)
            # Silence VTK warnings while reading, then restore the
            # previous setting.
            warn = r.global_warning_display
            r.global_warning_display = 0
            r.update()
            r.global_warning_display = warn
            self.data = r.output
        # Now set the remaining state without touching the children.
        set_state(self, state, ignore=['children', 'data'])
        # Setup the children.
        handle_children_state(self.children, state.children)
        # Setup the children's state.
        set_state(self, state, first=['children'], ignore=['*'])

    ######################################################################
    # `Base` interface
    ######################################################################
    def start(self):
        """This is invoked when this object is added to the mayavi
        pipeline.
        """
        # Do nothing if we are already running.
        if self.running:
            return

        # Update the data just in case.
        self._update_data()

        # Call the parent method to do its thing.  This will typically
        # start all our children.
        super(VTKDataSource, self).start()

    def update(self):
        """Invoke this to flush data changes downstream.  This is
        typically used when you change the data object and want the
        mayavi pipeline to refresh.
        """
        # This tells the VTK pipeline that the data has changed.  This
        # will fire the data_changed event automatically.
        self.data.modified()
        self._assign_attribute.update()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _data_changed(self, old, new):
        """Trait handler for `data`: rebuild the outputs, move the
        'ModifiedEvent' observer from the old dataset to the new one
        and refresh our name.
        """
        if has_attributes(self.data):
            aa = self._assign_attribute
            self.configure_input_data(aa, new)
            self._update_data()
            aa.update()
            self.outputs = [aa.output]
        else:
            self.outputs = [self.data]
        self.data_changed = True

        self.output_info.datasets = \
                [get_tvtk_dataset_name(self.outputs[0])]

        # Add an observer to the VTK dataset after removing the one
        # for the old dataset.  We use the messenger to avoid an
        # uncollectable reference cycle.  See the
        # tvtk.messenger module documentation for details.
        if old is not None:
            old.remove_observer(self._observer_id)
        self._observer_id = new.add_observer('ModifiedEvent', messenger.send)
        new_vtk = tvtk.to_vtk(new)
        messenger.connect(new_vtk, 'ModifiedEvent', self._fire_data_changed)

        # Change our name so that our label on the tree is updated.
        self.name = self._get_name()

    def _fire_data_changed(self, *args):
        """Simply fire the `data_changed` event."""
        self.data_changed = True

    def _set_data_name(self, data_type, attr_type, value):
        """Make `value` the active `data_type` ('scalars', 'vectors' or
        'tensors') array of the `attr_type` ('point' or 'cell') data.

        An empty `value` deactivates the attribute; None is ignored.
        """
        if value is None:
            return

        dataset = self.data
        if len(value) == 0:
            # If the value is empty then we deactivate that attribute.
            d = getattr(dataset, attr_type + '_data')
            method = getattr(d, 'set_active_%s' % data_type)
            method(None)
            self.data_changed = True
            return

        aa = self._assign_attribute
        data = None
        if attr_type == 'point':
            data = dataset.point_data
        elif attr_type == 'cell':
            data = dataset.cell_data
        method = getattr(data, 'set_active_%s' % data_type)
        method(value)
        aa.assign(value, data_type.upper(), attr_type.upper() + '_DATA')
        if data_type == 'scalars' and dataset.is_a('vtkImageData'):
            # Set the scalar_type for image data, if not you can either
            # get garbage rendered or worse.
            s = getattr(dataset, attr_type + '_data').scalars
            # The value is unused; presumably the range is read for the
            # same flushing side effect described in `_update_data`.
            r = s.range
            if is_old_pipeline():
                dataset.scalar_type = s.data_type
                aa.output.scalar_type = s.data_type
        aa.update()
        # Fire an event, so the changes propagate.
        self.data_changed = True

    def _point_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'point', value)

    def _point_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'point', value)

    def _point_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'point', value)

    def _cell_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'cell', value)

    def _cell_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'cell', value)

    def _cell_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'cell', value)

    def _update_data(self):
        """Refresh the attribute lists and the active attribute names
        from the current dataset, then fire `data_changed`.  Does
        nothing when there is no dataset.
        """
        if self.data is None:
            return
        pnt_attr, cell_attr = get_all_attributes(self.data)

        pd = self.data.point_data
        scalars = pd.scalars
        if self.data.is_a('vtkImageData') and scalars is not None:
            # For some reason getting the range of the scalars flushes
            # the data through to prevent some really strange errors
            # when using an ImagePlaneWidget.
            r = scalars.range
            if is_old_pipeline():
                self._assign_attribute.output.scalar_type = scalars.data_type
                self.data.scalar_type = scalars.data_type

        def _setup_data_traits(obj, attributes, d_type):
            """Given the object, the dict of the attributes from the
            `get_all_attributes` function and the data type
            (point/cell) data this will setup the object and the data.
            """
            attrs = ['scalars', 'vectors', 'tensors']
            aa = obj._assign_attribute
            data = getattr(obj.data, '%s_data' % d_type)
            for attr in attrs:
                values = attributes[attr]
                values.append('')
                setattr(obj, '_%s_%s_list' % (d_type, attr), values)
                if len(values) > 1:
                    default = getattr(obj, '%s_%s_name' % (d_type, attr))
                    # On first use, fall back to the first available
                    # array when nothing has been chosen yet.
                    if obj._first and len(default) == 0:
                        default = values[0]
                    getattr(data, 'set_active_%s' % attr)(default)
                    aa.assign(default, attr.upper(), d_type.upper() + '_DATA')
                    aa.update()
                    kw = {
                        '%s_%s_name' % (d_type, attr): default,
                        'trait_change_notify': False
                    }
                    # `trait_set` rather than the legacy `set` alias;
                    # the change handlers must not fire here.
                    obj.trait_set(**kw)

        _setup_data_traits(self, pnt_attr, 'point')
        _setup_data_traits(self, cell_attr, 'cell')
        if self._first:
            self._first = False
        # Propagate the data changed event.
        self.data_changed = True

    def _get_name(self):
        """ Gets the name to display on the tree.
        """
        ret = "VTK Data (uninitialized)"
        if self.data is not None:
            typ = self.data.__class__.__name__
            ret = "VTK Data (%s)" % typ
        if '[Hidden]' in self.name:
            ret += ' [Hidden]'
        return ret
Example #5
0
class VTKXMLFileReader(FileDataSource):
    """A VTK XML file reader.  The reader supports all the different
    types of data sets.  This reader also supports a time series.
    Currently, this reader assumes that there is only one output that
    has configurable attributes.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    ########################################
    # Dynamic traits: These traits are dynamic and are automatically
    # updated depending on the contents of the file.

    # The active point scalar name.  An empty string indicates that
    # the attribute is "deactivated".  This is useful when you have
    # both point and cell attributes and want to use cell data by
    # default.
    point_scalars_name = DEnum(values_name='_point_scalars_list',
                               desc='scalar point data attribute to use')
    # The active point vector name.
    point_vectors_name = DEnum(values_name='_point_vectors_list',
                               desc='vectors point data attribute to use')
    # The active point tensor name.
    point_tensors_name = DEnum(values_name='_point_tensors_list',
                               desc='tensor point data attribute to use')

    # The active cell scalar name.
    cell_scalars_name = DEnum(values_name='_cell_scalars_list',
                              desc='scalar cell data attribute to use')
    # The active cell vector name.
    cell_vectors_name = DEnum(values_name='_cell_vectors_list',
                              desc='vectors cell data attribute to use')
    # The active cell tensor name.
    cell_tensors_name = DEnum(values_name='_cell_tensors_list',
                              desc='tensor cell data attribute to use')
    ########################################

    # The VTK data file reader.
    reader = Instance(tvtk.XMLReader)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Our view.
    view = View(
        Group(
            Include('time_step_group'),
            Item(name='point_scalars_name'),
            Item(name='point_vectors_name'),
            Item(name='point_tensors_name'),
            Item(name='cell_scalars_name'),
            Item(name='cell_vectors_name'),
            Item(name='cell_tensors_name'),
            Item(name='reader'),
        ))

    ########################################
    # Private traits.

    # These private traits store the list of available data
    # attributes.  The non-private traits use these lists internally.
    _point_scalars_list = List(Str)
    _point_vectors_list = List(Str)
    _point_tensors_list = List(Str)
    _cell_scalars_list = List(Str)
    _cell_vectors_list = List(Str)
    _cell_tensors_list = List(Str)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute,
                                 args=(),
                                 allow_none=False)

    # Toggles if this is the first time this object has been used.
    _first = Bool(True)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the parent's pure state, minus transient entries and
        with the six active attribute names stored explicitly.

        The `_assign_attribute` filter, the `_first` flag and the
        private ``_<kind>_list`` / ``_<kind>_name`` entries are dropped
        from the state.
        """
        state = super(VTKXMLFileReader, self).__get_pure_state__()
        for transient in ('_assign_attribute', '_first'):
            state.pop(transient, None)

        kinds = ('point_scalars', 'point_vectors', 'point_tensors',
                 'cell_scalars', 'cell_vectors', 'cell_tensors')
        for kind in kinds:
            state.pop('_%s_list' % kind, None)
            state.pop('_%s_name' % kind, None)
            # 'point_scalars_name' etc. are properties and not in
            # __dict__, so their current values are saved by hand.
            trait_name = kind + '_name'
            state[trait_name] = getattr(self, trait_name)
        return state

    def __set_pure_state__(self, state):
        """Restore this object from `state`.

        The saved reader keeps a file name of its own, so it is first
        pointed at the absolute path recorded in `state.file_path`;
        the parent class then restores everything else.
        """
        absolute_path = state.file_path.abs_pth
        state.reader.file_name = absolute_path
        super(VTKXMLFileReader, self).__set_pure_state__(state)

    ######################################################################
    # `Base` interface
    ######################################################################
    def start(self):
        """Invoked when this object is added to the mayavi pipeline.

        Nothing happens if the object is already running; otherwise
        the parent implementation is called, which will typically
        start all our children.
        """
        if not self.running:
            super(VTKXMLFileReader, self).start()

    def stop(self):
        """Invoked when this object is removed from the mayavi
        pipeline.

        Nothing happens unless the object is currently running, in
        which case the parent implementation is called.
        """
        if self.running:
            super(VTKXMLFileReader, self).stop()

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def update(self):
        """Update the reader and re-render.

        Does nothing while the file path is empty.
        """
        if len(self.file_path.get()) > 0:
            self.reader.update()
            self.render()

    def update_data(self):
        """Refresh the attribute lists and the active attributes from
        the reader's current output.

        Does nothing while the file path is empty.  Otherwise the
        reader is updated and, for cell data and then point data, the
        attribute names returned by `get_all_attributes` are stored in
        the private ``_<cell|point>_<attr>_list`` traits.  Where at
        least one real attribute exists, the active one is set on the
        reader output and on the assign-attribute filter.  Finally
        `data_changed` is fired.
        """
        if len(self.file_path.get()) == 0:
            return
        self.reader.update()
        pnt_attr, cell_attr = get_all_attributes(self.reader.output)

        # Cell data is handled first, then point data.
        for d_type, attributes in (('cell', cell_attr), ('point', pnt_attr)):
            assigner = self._assign_attribute
            attr_data = getattr(self.reader.output, '%s_data' % d_type)
            for attr in ('scalars', 'vectors', 'tensors'):
                names = attributes[attr]
                # The empty string is the "deactivated" choice.
                names.append('')
                setattr(self, '_%s_%s_list' % (d_type, attr), names)
                if len(names) <= 1:
                    # Only the empty entry is present: nothing to activate.
                    continue
                trait_name = '%s_%s_name' % (d_type, attr)
                active = getattr(self, trait_name)
                if self._first and len(active) == 0:
                    # On first use, fall back to the first available name.
                    active = names[0]
                getattr(attr_data, 'set_active_%s' % attr)(active)
                assigner.assign(active, attr.upper(),
                                d_type.upper() + '_DATA')
                assigner.update()
                # Store the name without firing the change notification.
                self.set(**{trait_name: active,
                            'trait_change_notify': False})

        if self._first:
            self._first = False
        # Propagate the data changed event.
        self.data_changed = True

    def has_output_port(self):
        """Return True, indicating that this reader provides an output
        port (the port itself is returned by `get_output_object`)."""
        return True

    def get_output_object(self):
        """Return the `output_port` of the underlying reader."""
        reader = self.reader
        return reader.output_port

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_path_changed(self, fpath):
        """Reconfigure the reader and the outputs for a new file path.

        `fpath` is the new file path object; its `get()` value is the
        file name handed to the reader.  An empty value is ignored.

        If no reader exists yet, one is created from the data type
        reported by `find_file_data_type`.  The reader is then pointed
        at the file and updated, its outputs are collected (the first
        one routed through the assign-attribute filter), the attribute
        traits are refreshed via `update_data`, and `outputs`,
        `output_info.datasets` and `name` are updated.

        Raises AttributeError if `tvtk` has no ``XML<type>Reader``
        class for the detected data type.
        """
        value = fpath.get()
        if len(value) == 0:
            return

        if self.reader is None:
            d_type = find_file_data_type(value)
            # NOTE(review): this used to be
            # eval('tvtk.XML%sReader()' % d_type).  `d_type` comes from
            # `find_file_data_type` (presumably read from the file
            # itself), so evaluating it as code is unsafe; the class is
            # looked up by name instead.
            reader_class = getattr(tvtk, 'XML%sReader' % d_type)
            self.reader = reader_class()
        reader = self.reader
        reader.file_name = value
        reader.update()

        # Setup the outputs by resetting self.outputs.  Changing
        # the outputs automatically fires a pipeline_changed
        # event.
        try:
            n = reader.number_of_outputs
        except AttributeError:  # for VTK >= 4.5
            n = reader.number_of_output_ports
        outputs = [reader.get_output(i) for i in range(n)]

        # FIXME: Only the first output goes through the assign
        # attribute filter.
        aa = self._assign_attribute
        self.configure_input_data(aa, outputs[0])
        outputs[0] = aa.output
        # The attribute traits are refreshed before the outputs are
        # published.
        self.update_data()

        self.outputs = outputs

        # FIXME: The output info is only based on the first output.
        self.output_info.datasets = [get_tvtk_dataset_name(outputs[0])]

        # Change our name on the tree view
        self.name = self._get_name()

    def _set_data_name(self, data_type, attr_type, value):
        """Activate (or deactivate) an attribute on the reader output.

        `data_type` is the attribute kind used to build the
        ``set_active_<data_type>`` method name, `attr_type` is
        ``'point'`` or ``'cell'`` and `value` is the attribute name.

        A `value` of None is ignored.  An empty `value` deactivates
        the attribute by passing None to the setter.  Any other value
        is made active on the reader output and assigned on the
        assign-attribute filter.  `data_changed` is fired in both of
        the latter cases.
        """
        if value is None:
            return

        output = self.reader.output
        setter_name = 'set_active_%s' % data_type
        if len(value) == 0:
            # If the value is empty then we deactivate that attribute.
            attr_data = getattr(output, attr_type + '_data')
            getattr(attr_data, setter_name)(None)
            self.data_changed = True
            return

        assigner = self._assign_attribute
        if attr_type == 'point':
            attr_data = output.point_data
        elif attr_type == 'cell':
            attr_data = output.cell_data
        else:
            attr_data = None

        getattr(attr_data, setter_name)(value)
        assigner.assign(value, data_type.upper(),
                        attr_type.upper() + '_DATA')
        assigner.update()
        # Fire an event, so the changes propagate.
        self.data_changed = True

    def _point_scalars_name_changed(self, value):
        """Forward a new `point_scalars_name` to `_set_data_name`."""
        self._set_data_name(data_type='scalars', attr_type='point',
                            value=value)

    def _point_vectors_name_changed(self, value):
        """Forward a new `point_vectors_name` to `_set_data_name`."""
        self._set_data_name(data_type='vectors', attr_type='point',
                            value=value)

    def _point_tensors_name_changed(self, value):
        """Forward a new `point_tensors_name` to `_set_data_name`."""
        self._set_data_name(data_type='tensors', attr_type='point',
                            value=value)

    def _cell_scalars_name_changed(self, value):
        """Forward a new `cell_scalars_name` to `_set_data_name`."""
        self._set_data_name(data_type='scalars', attr_type='cell',
                            value=value)

    def _cell_vectors_name_changed(self, value):
        """Forward a new `cell_vectors_name` to `_set_data_name`."""
        self._set_data_name(data_type='vectors', attr_type='cell',
                            value=value)

    def _cell_tensors_name_changed(self, value):
        """Forward a new `cell_tensors_name` to `_set_data_name`."""
        self._set_data_name(data_type='tensors', attr_type='cell',
                            value=value)

    def _get_name(self):
        """Build the label to display for this reader on the tree view.

        The label is ``VTK XML file (<base name of the file path>)``,
        followed by `` (timeseries)`` when `file_list` holds more than
        one entry and by `` [Hidden]`` when the current `name` already
        contains ``[Hidden]``.
        """
        pieces = ["VTK XML file (%s)" % basename(self.file_path.get())]
        if len(self.file_list) > 1:
            pieces.append(" (timeseries)")
        if '[Hidden]' in self.name:
            pieces.append(' [Hidden]')
        return ''.join(pieces)