Example #1
0
class FiducialsPanel(HasPrivateTraits):
    """Set fiducials on an MRI surface.

    The LPA, nasion and RPA positions (and the file/state flags) are
    delegated to ``model``. The fiducial chosen with ``set`` is mirrored in
    ``current_pos`` for editing, and picks on the MRI surface object
    ``hsp_obj`` (see ``_on_pick``) move that fiducial to the picked point.
    """
    model = Instance(MRIHeadWithFiducialsModel)

    fid_file = DelegatesTo('model')
    fid_fname = DelegatesTo('model')
    lpa = DelegatesTo('model')
    nasion = DelegatesTo('model')
    rpa = DelegatesTo('model')
    can_save = DelegatesTo('model')
    can_save_as = DelegatesTo('model')
    can_reset = DelegatesTo('model')
    fid_ok = DelegatesTo('model')
    locked = DelegatesTo('model', 'lock_fiducials')

    # Which fiducial is currently being edited / set by picking.
    set = Enum('LPA', 'Nasion', 'RPA')
    current_pos = Array(float, (1, 3))  # for editing

    save_as = Button(label='Save As...')
    save = Button(label='Save')
    reset_fid = Button(label="Reset to File")

    headview = Instance(HeadViewController)
    hsp_obj = Instance(SurfaceObject)

    # The most recent picker passed to ``_on_pick``.
    picker = Instance(object)

    # the layout of the dialog created
    view = View(
        VGroup(Item('fid_file', label='Fiducials File'),
               Item('fid_fname', show_label=False, style='readonly'),
               Item('set', style='custom'),
               Item('current_pos', label='Pos'),
               HGroup(Item('save',
                           enabled_when='can_save',
                           tooltip="If a filename is currently "
                           "specified, save to that file, otherwise "
                           "save to the default file name"),
                      Item('save_as', enabled_when='can_save_as'),
                      Item('reset_fid', enabled_when='can_reset'),
                      show_labels=False),
               enabled_when="locked==False"))

    def __init__(self, *args, **kwargs):
        super(FiducialsPanel, self).__init__(*args, **kwargs)
        # ``set`` starts at its first value, 'LPA', so mirror the LPA first;
        # _on_set_change re-targets this sync whenever the selection changes.
        self.sync_trait('lpa', self, 'current_pos', mutual=True)

    def _reset_fid_fired(self):
        """Ask the model to reset the fiducials."""
        self.model.reset = True

    def _save_fired(self):
        """Save the fiducials via ``model.save()`` (no explicit path)."""
        self.model.save()

    def _save_as_fired(self):
        """Ask for a file name and save the fiducials there as ``.fif``."""
        if self.fid_file:
            default_path = self.fid_file
        else:
            default_path = self.model.default_fid_fname

        dlg = FileDialog(action="save as",
                         wildcard=fid_wildcard,
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        path = dlg.path
        if not path.endswith('.fif'):
            path = path + '.fif'
            # Appending the extension yields a different file than the one
            # picked in the dialog, so check for an overwrite explicitly.
            if os.path.exists(path):
                msg = ("The file %r already exists. Should it be replaced?"
                       % path)
                answer = confirm(None, msg, "Overwrite File?")
                if answer != YES:
                    return

        self.model.save(path)

    def _on_pick(self, picker):
        """Move the currently selected fiducial to a picked MRI point.

        Parameters
        ----------
        picker : object
            The picker that fired; it must provide ``picked_positions``,
            ``pick_position``, ``actor`` and ``actors``.
        """
        if self.locked:
            return

        self.picker = picker
        n_pos = len(picker.picked_positions)

        if n_pos == 0:
            logger.debug("GUI: picked empty location")
            return

        mri_actor = self.hsp_obj.surf.actor.actor
        if picker.actor is mri_actor:
            idxs = []
            idx = None
            pt = [picker.pick_position]
        elif mri_actor in picker.actors:
            idxs = [i for i in range(n_pos) if picker.actors[i] is mri_actor]
            # Use the last hit on the MRI mesh.
            idx = idxs[-1]
            pt = [picker.picked_positions[idx]]
        else:
            # Nothing on the MRI surface was hit, so there is no point to set
            # (and idx/idxs/pt are undefined) -- stop here.
            logger.debug("GUI: picked object other than MRI")
            return

        def round_(x):
            return round(x, 3)

        # Build real lists so the log shows coordinates (a lazy ``map``
        # object would only print its repr under Python 3).
        poss = [[round_(x) for x in pos] for pos in picker.picked_positions]
        pos = [round_(x) for x in picker.pick_position]
        msg = ["Pick Event: %i picked_positions:" % n_pos]

        line = str(pos)
        if idx is None:
            line += " <-pick_position"
        msg.append(line)

        for i, pos in enumerate(poss):
            line = str(pos)
            if i == idx:
                line += " <- MRI mesh"
            elif i in idxs:
                line += " (<- also MRI mesh)"
            msg.append(line)
        logger.debug('\n'.join(msg))

        if self.set == 'Nasion':
            self.nasion = pt
        elif self.set == 'LPA':
            self.lpa = pt
        elif self.set == 'RPA':
            self.rpa = pt
        else:
            raise ValueError("set = %r" % self.set)

    @on_trait_change('set')
    def _on_set_change(self, obj, name, old, new):
        """Re-target ``current_pos`` to the newly selected fiducial.

        Drops the mutual sync with the previously selected fiducial, syncs
        with the new one and switches ``headview`` to the front (nasion),
        left (LPA) or right (RPA) view.
        """
        self.sync_trait(old.lower(),
                        self,
                        'current_pos',
                        mutual=True,
                        remove=True)
        self.sync_trait(new.lower(), self, 'current_pos', mutual=True)
        if new == 'Nasion':
            self.headview.front = True
        elif new == 'LPA':
            self.headview.left = True
        elif new == 'RPA':
            self.headview.right = True
Example #2
0
 def traits_view(self):
     """Build the view: ``data`` shown label-less in a tabular editor."""
     data_item = Item('data',
                      editor=TabularEditor(adapter=DatumAdapter()),
                      height=500,
                      show_label=False)
     return View(data_item)
Example #3
0
 def traits_view(self):
     """Lay out ``use``, ``center``, ``threshold`` and ``color`` in one row."""
     row = HGroup(
         Item('use', show_label=False),
         Item('center'),
         Item('threshold'),
         Item('color', style='custom', show_label=False),
     )
     return View(row)
Example #4
0
 def traits_view(self):
     """Build the view: a 'Grid' toggle followed by ``fps`` and ``quality``."""
     items = [Item('show_grids', label='Grid'), Item('fps'), Item('quality')]
     return View(*items)
Example #5
0
class FiducialsPanel(HasPrivateTraits):
    """Set fiducials on an MRI surface.

    The LPA, nasion and RPA positions are delegated to ``model``. The
    fiducial selected with ``set`` is shown and edited in millimetres through
    ``current_pos_mm`` (the model values are 1000x smaller, see
    ``_update_pos``); picks on ``hsp_obj`` move the selected fiducial.
    """

    model = Instance(MRIHeadWithFiducialsModel)

    fid_file = DelegatesTo('model')
    fid_fname = DelegatesTo('model')
    lpa = DelegatesTo('model')
    nasion = DelegatesTo('model')
    rpa = DelegatesTo('model')
    can_save = DelegatesTo('model')
    can_save_as = DelegatesTo('model')
    can_reset = DelegatesTo('model')
    fid_ok = DelegatesTo('model')
    locked = DelegatesTo('model', 'lock_fiducials')

    # Which fiducial is currently being edited / set by picking.
    set = Enum('LPA', 'Nasion', 'RPA')
    # Position of the selected fiducial, in mm.
    current_pos_mm = Array(float, (1, 3))

    save_as = Button(label='Save as...')
    save = Button(label='Save')
    reset_fid = Button(label=_RESET_LABEL)

    headview = Instance(HeadViewController)
    hsp_obj = Instance(SurfaceObject)

    # The most recent picker passed to ``_on_pick``.
    picker = Instance(object)

    # the layout of the dialog created
    view = View(VGroup(
        HGroup(Item('fid_file', width=_MRI_FIDUCIALS_WIDTH,
                    tooltip='MRI fiducials file'), show_labels=False),
        HGroup(Item('set', width=_MRI_FIDUCIALS_WIDTH,
                    format_func=lambda x: x, style='custom',
                    tooltip=_SET_TOOLTIP), show_labels=False),
        HGroup(Item('current_pos_mm',
                    editor=ArrayEditor(width=_MM_WIDTH, format_func=_mm_fmt),
                    tooltip='MRI fiducial position (mm)'), show_labels=False),
        HGroup(Item('save', enabled_when='can_save',
                    tooltip="If a filename is currently specified, save to "
                    "that file, otherwise save to the default file name",
                    width=_BUTTON_WIDTH),
               Item('save_as', enabled_when='can_save_as',
                    width=_BUTTON_WIDTH),
               Item('reset_fid', enabled_when='can_reset', width=_RESET_WIDTH,
                    tooltip='Reset to file values (if available)'),
               show_labels=False),
        enabled_when="locked==False", show_labels=False), handler=SetHandler())

    def __init__(self, *args, **kwargs):  # noqa: D102
        super(FiducialsPanel, self).__init__(*args, **kwargs)

    @on_trait_change('current_pos_mm')
    def _update_pos(self):
        """Write an edited ``current_pos_mm`` back to the selected fiducial."""
        attr = self.set.lower()
        # Only write when the value actually differs, so the model -> panel
        # updates below do not bounce back here as a redundant assignment.
        if not np.allclose(getattr(self, attr), self.current_pos_mm * 1e-3):
            setattr(self, attr, self.current_pos_mm * 1e-3)

    @on_trait_change('model:lpa')
    def _update_lpa(self, name):
        """Refresh ``current_pos_mm`` when the model LPA changes."""
        if self.set == 'LPA':
            self.current_pos_mm = self.lpa * 1000

    @on_trait_change('model:nasion')
    def _update_nasion(self, name):
        """Refresh ``current_pos_mm`` when the model nasion changes."""
        # Compare against the Enum value directly: lower-casing ``set`` and
        # comparing to 'Nasion' could never match.
        if self.set == 'Nasion':
            self.current_pos_mm = self.nasion * 1000

    @on_trait_change('model:rpa')
    def _update_rpa(self, name):
        """Refresh ``current_pos_mm`` when the model RPA changes."""
        if self.set == 'RPA':
            self.current_pos_mm = self.rpa * 1000

    def _reset_fid_fired(self):
        """Ask the model to reset the fiducials."""
        self.model.reset = True

    def _save_fired(self):
        """Save the fiducials via ``model.save()`` (no explicit path)."""
        self.model.save()

    def _save_as_fired(self):
        """Ask for a file name and save the fiducials there as ``.fif``."""
        if self.fid_file:
            default_path = self.fid_file
        else:
            default_path = self.model.default_fid_fname

        dlg = FileDialog(action="save as", wildcard=fid_wildcard,
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        path = dlg.path
        if not path.endswith('.fif'):
            path = path + '.fif'
            # Appending the extension yields a different file than the one
            # picked in the dialog, so check for an overwrite explicitly.
            if os.path.exists(path):
                msg = ("The file %r already exists. Should it be replaced?"
                       % path)
                answer = confirm(None, msg, "Overwrite File?")
                if answer != YES:
                    return

        self.model.save(path)

    def _on_pick(self, picker):
        """Move the currently selected fiducial to a picked MRI point.

        Parameters
        ----------
        picker : object
            The picker that fired; it must provide ``picked_positions``,
            ``pick_position``, ``actor`` and ``actors``.
        """
        if self.locked:
            return

        self.picker = picker
        n_pos = len(picker.picked_positions)

        if n_pos == 0:
            logger.debug("GUI: picked empty location")
            return

        mri_actor = self.hsp_obj.surf.actor.actor
        if picker.actor is mri_actor:
            idxs = []
            idx = None
            pt = [picker.pick_position]
        elif mri_actor in picker.actors:
            idxs = [i for i in range(n_pos) if picker.actors[i] is mri_actor]
            # Use the last hit on the MRI mesh.
            idx = idxs[-1]
            pt = [picker.picked_positions[idx]]
        else:
            # Nothing on the MRI surface was hit, so there is no point to set
            # (and idx/idxs/pt are undefined) -- stop here.
            logger.debug("GUI: picked object other than MRI")
            return

        def round_(x):
            return round(x, 3)

        # Build real lists so the log shows coordinates (a lazy ``map``
        # object would only print its repr under Python 3).
        poss = [[round_(x) for x in pos] for pos in picker.picked_positions]
        pos = [round_(x) for x in picker.pick_position]
        msg = ["Pick Event: %i picked_positions:" % n_pos]

        line = str(pos)
        if idx is None:
            line += " <-pick_position"
        msg.append(line)

        for i, pos in enumerate(poss):
            line = str(pos)
            if i == idx:
                line += " <- MRI mesh"
            elif i in idxs:
                line += " (<- also MRI mesh)"
            msg.append(line)
        logger.debug('\n'.join(msg))

        if self.set == 'Nasion':
            self.nasion = pt
        elif self.set == 'LPA':
            self.lpa = pt
        elif self.set == 'RPA':
            self.rpa = pt
        else:
            raise ValueError("set = %r" % self.set)

    @on_trait_change('set')
    def _on_set_change(self, obj, name, old, new):
        """Show the newly selected fiducial (in mm) and turn the head view.

        The view switches to front (nasion), left (LPA) or right (RPA).
        """
        if new == 'Nasion':
            self.current_pos_mm = self.nasion * 1000
            self.headview.front = True
        elif new == 'LPA':
            self.current_pos_mm = self.lpa * 1000
            self.headview.left = True
        elif new == 'RPA':
            self.current_pos_mm = self.rpa * 1000
            self.headview.right = True
class ImageReader(FileDataSource):
    """An image file reader.

    The concrete tvtk reader is chosen from the extension of ``file_path``.
    BMP, JPEG, PNG, PNM, DICOM, TIFF, GE Signa (``.ximg``), DEM and MetaImage
    files get dedicated readers, as do MINC files when tvtk provides
    ``MINCImageReader``. Any other extension falls back to
    ``tvtk.ImageReader``.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The Image data file reader.
    reader = Instance(tvtk.Object, allow_none=False, record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Our view.
    view = View(Group(Include('time_step_group'),
                      Item(name='base_file_name'),
                      Item(name='reader', style='custom', resizable=True),
                      show_labels=False),
                resizable=True)

    ######################################################################
    # Private Traits
    # Lower-case file extension -> reader used for files with it.
    _image_reader_dict = Dict(Str, Instance(tvtk.Object))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        readers = dict(
            bmp=tvtk.BMPReader(),
            jpg=tvtk.JPEGReader(),
            png=tvtk.PNGReader(),
            pnm=tvtk.PNMReader(),
            dcm=tvtk.DICOMImageReader(),
            tiff=tvtk.TIFFReader(),
            ximg=tvtk.GESignaReader(),
            dem=tvtk.DEMReader(),
            mha=tvtk.MetaImageReader(),
            mhd=tvtk.MetaImageReader(),
        )
        # VTK releases before 5.2 ship no MINC reader.
        if hasattr(tvtk, 'MINCImageReader'):
            readers['mnc'] = tvtk.MINCImageReader()
        # '.jpeg' reuses the very same reader object as '.jpg'.
        readers['jpeg'] = readers['jpg']
        self._image_reader_dict = readers
        # The table is filled before the base-class init, presumably so a
        # file_path passed in **traits can already find its reader.
        super(ImageReader, self).__init__(**traits)

    def __set_pure_state__(self, state):
        # Point the reader's own file_name at the restored absolute path
        # before the parent class applies the rest of the state.
        state.reader.file_name = state.file_path.abs_pth
        super(ImageReader, self).__set_pure_state__(state)

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def update(self):
        """Update the reader and re-render when a file path is set."""
        self.reader.update()
        if len(self.file_path.get()) > 0:
            self.render()

    def has_output_port(self):
        """ Return True as the reader has output port."""
        return True

    def get_output_object(self):
        """ Return the reader output port."""
        return self.reader.output_port

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_path_changed(self, fpath):
        value = fpath.get()
        if len(value) == 0:
            return
        # The lower-cased text after the last '.' selects the reader.
        extension = value.strip().split('.')[-1].lower()
        old_reader = self.reader
        readers = self._image_reader_dict
        self.reader = (readers[extension] if extension in readers
                       else tvtk.ImageReader())

        self.reader.file_name = value.strip()
        self.reader.update()
        self.reader.update_information()

        # Move the render hook from the previous reader to the current one.
        if old_reader is not None:
            old_reader.on_trait_change(self.render, remove=True)
        self.reader.on_trait_change(self.render)

        self.outputs = [self.reader]

        # Refresh the label shown on the tree view.
        self.name = self._get_name()

    def _get_name(self):
        """ Returns the name to display on the tree view.  Note that
        this is not a property getter.
        """
        label = "%s" % basename(self.file_path.get())
        is_series = len(self.file_list) > 1
        is_hidden = '[Hidden]' in self.name
        return (label
                + (" (timeseries)" if is_series else "")
                + (' [Hidden]' if is_hidden else ''))
Example #7
0
class SyntaxChecker ( Saveable ):
    """ Edits a Python source file and reports syntax errors in it.

        The file named by 'file_name' is loaded into 'source' (and reloaded
        on external changes when 'auto_load' is set). Each check compiles
        'source' and reports any SyntaxError through 'error', 'error_line'
        and 'error_column'.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The name of the plugin:
    name = Str( 'Syntax Checker' )

    # The persistence id for this object:
    id = Str( 'etsdevtools.developer.tools.syntax_checker.state',
              save_state_id = True )

    # Should the syntax checker automatically go to the current syntax error?
    auto_goto = Bool( False, save_state = True )

    # Should a changed file be automatically reloaded:
    auto_load = Bool( True, save_state = True )

    # The name of the file currently being syntax checked:
    file_name = File( drop_file = DropFile( extensions = [ '.py' ],
                                    draggable = True,
                                    tooltip   = 'Drop a Python source file to '
                                          'syntax check it.\nDrag this file.' ),
                      connect   = 'to' )

    # The current source code being syntax checked:
    source = Str

    # The current syntax error message:
    error = Str

    # Current error line:
    error_line = Int

    # Current error column:
    error_column = Int

    # Current editor line:
    line = Int

    # Current editor column:
    column = Int

    # 'Go to' button:
    go_to = Button( 'Go To' )

    # Can the current file be saved?
    can_save = false

    #---------------------------------------------------------------------------
    #  Traits view definitions:
    #---------------------------------------------------------------------------

    traits_view = View(
        VGroup(
            TTitle( 'file_name' ),
            Item( 'source@',
                  editor = CodeEditor( selected_line = 'line' ) ),
            TTitle( 'error' ),
            HGroup(
                spring,
                Item( 'go_to',
                      show_label   = False,
                      enabled_when = '(error_line > 0) and (not auto_goto)' ),
            ),
            show_labels = False
        ),
        title = 'Syntax Checker'
    )

    options = View(
        VGroup(
            Item( 'auto_goto',
                  label = 'Automatically move cursor to syntax error'
            ),
            Item( 'auto_load',
                  label = 'Automatically reload externally changed files'
            ),
            show_left = False
        ),
        title   = 'Syntax Checker Options',
        id      = 'etsdevtools.developer.tools.syntax_checker.options',
        buttons = [ 'OK', 'Cancel' ]
    )

    #---------------------------------------------------------------------------
    #  Handles the 'auto_goto' trait being changed:
    #---------------------------------------------------------------------------

    def _auto_goto_changed ( self, auto_goto ):
        """ Handles the 'auto_goto' trait being changed.

            Jumps to the current error right away when the option is turned
            on while an error is pending.
        """
        if auto_goto and (self.error_line > 0):
            self._go_to_changed()

    #---------------------------------------------------------------------------
    #  Handles the 'Go To' button being clicked:
    #---------------------------------------------------------------------------

    def _go_to_changed ( self ):
        """ Handles the 'Go To' button being clicked.

            Moves the editor cursor to the recorded error position.
        """
        self.line   = self.error_line
        self.column = self.error_column

    #---------------------------------------------------------------------------
    #  Handles the 'file_name' trait being changed:
    #---------------------------------------------------------------------------

    def _file_name_changed ( self, old_name, new_name ):
        """ Handles the 'file_name' trait being changed.

            Moves the file watch from the old file to the new one, then
            loads the new file.
        """
        self._set_listener( old_name, True )
        self._set_listener( new_name, False )
        self._load_file_name( new_name )

    #---------------------------------------------------------------------------
    #  Handles the 'source' trait being changed:
    #---------------------------------------------------------------------------

    def _source_changed ( self, source ):
        """ Handles the 'source' trait being changed.

            Edits made outside '_load_file_name' mark the file as needing a
            save and schedule a delayed syntax check; text set while loading
            is checked immediately. Nothing happens unless 'can_save' is set.
        """
        if self.can_save:
            if not self._dont_update:
                self.needs_save = True
                do_after( 750, self._syntax_check )
            else:
                self._syntax_check()

    #---------------------------------------------------------------------------
    #  Handles the current file being updated:
    #---------------------------------------------------------------------------

    def _file_changed ( self, file_name ):
        """ Handles the current file being updated.

            Registered with 'file_watch' by '_set_listener'; reloads the file
            when 'auto_load' is set.
        """
        if self.auto_load:
            self._load_file_name( file_name )

    #---------------------------------------------------------------------------
    #  Sets up/Removes a file watch on a specified file:
    #---------------------------------------------------------------------------

    def _set_listener ( self, file_name, remove ):
        """ Sets up/Removes a file watch on a specified file.

            Files that do not exist (e.g. an empty name) are ignored.
        """
        if exists( file_name ):
            file_watch.watch( self._file_changed, file_name, remove = remove )

    #---------------------------------------------------------------------------
    #  Loads a specified source file:
    #---------------------------------------------------------------------------

    def _load_file_name ( self, file_name ):
        """ Loads a specified source file.

            On a read failure 'error' is set, saving is disabled and the
            source is cleared. The file is marked as not needing a save.
        """
        # '_dont_update' tells '_source_changed' that this change comes from
        # loading, not from editing. Always clear it again, even if reading
        # or a change handler raises, so later edits are still tracked.
        self._dont_update = True
        try:
            self.can_save = True
            source        = read_file( file_name )
            if source is None:
                self.error    = 'Error reading file'
                self.can_save = False
                source        = ''
            self.source = source
        finally:
            self._dont_update = False
        self.needs_save = False

    #---------------------------------------------------------------------------
    #  Checks the current source for syntax errors:
    #---------------------------------------------------------------------------

    def _syntax_check ( self ):
        """ Checks the current source for syntax errors.

            Sets 'error' to a status message. On a SyntaxError it also sets
            'error_line'/'error_column' and, if 'auto_goto' is set, moves the
            editor cursor there.
        """
        try:
            compile( self.source.replace( '\r\n', '\n' ),
                     self.file_name, 'exec' )
            self.error      = 'Syntactically correct'
            self.error_line = 0
        except SyntaxError as excp:
            # 'lineno'/'offset' can be None for some errors; the Int traits
            # reject None (and None + 1 raises), so fall back to 0.
            lineno            = excp.lineno or 0
            self.error_line   = lineno
            self.error_column = (excp.offset or 0) + 1
            self.error        = '%s on line %d, column %d' % (
                                excp.msg, lineno, self.error_column )
            if self.auto_goto:
                self._go_to_changed()
Example #8
0
def create_view(model_view, readonly=False, show_units=True):
    """Build the View showing a model's inputs, outputs and code.

    Parameters
    ----------
    model_view :
        Passed through unchanged as the ``model_view`` of the View.
    readonly : bool
        If True the code editor uses the 'readonly' style, otherwise 'simple'.
    show_units : bool
        If True the input/output tables get a 'Units' column between 'Name'
        and 'Value'.

    Returns
    -------
    View
        A resizable 720x700 view with OK/Cancel buttons.
    """
    if show_units:
        columns = [
            ObjectColumn(name='name', label='Name', editable=False, width=0.3),
            ObjectColumn(name='units',
                         label='Units',
                         editable=False,
                         width=0.3),
            ObjectColumn(name='binding',
                         label='Value',
                         editable=True,
                         width=0.4),
        ]
    else:
        columns = [
            ObjectColumn(name='name', label='Name', editable=False, width=0.4),
            ObjectColumn(name='binding',
                         label='Value',
                         editable=True,
                         width=0.6),
        ]

    code_editor_style = 'readonly' if readonly else 'simple'

    def table_group(title, trait_name):
        """Return a titled VGroup with a minimal table for *trait_name*."""
        label = Label(title)
        # minimum settings to get rid of
        # toolbar at top of table.
        editor = TableEditor(
            columns=columns,
            editable=True,
            configurable=False,
            sortable=False,
            sort_model=True,
            selection_bg_color='white',
            selection_color='black',
            label_bg_color=WindowColor,
            cell_bg_color='white',
        )
        return VGroup(label, Item(trait_name, editor=editor, show_label=False))

    view = View(
        VSplit(
            HGroup(
                table_group("Inputs", 'object.model.inputs'),
                table_group("Outputs", 'object.model.outputs'),
            ),
            Group(
                Item('object.model.code',
                     editor=CodeEditor(),
                     style=code_editor_style,
                     show_label=False), ),
        ),
        model_view=model_view,
        width=720,  # about 80 columns wide on code view.
        height=700,
        resizable=True,
        buttons=menu.OKCancelButtons,
        close_result=False,
    )

    return view
Example #9
0
 def traits_view(self):
     """Show ``nominal_hv`` and a read-only ``current_hv``, both as %0.4f."""
     nominal = Item('nominal_hv', format_str='%0.4f')
     current = Item('current_hv', format_str='%0.4f', style='readonly')
     return View(nominal, current)
Example #10
0
class MarkerPointDest(MarkerPoints):  # noqa: D401
    """MarkerPoints subclass that serves for derived points.

    ``points`` are computed from the two sources ``src1``/``src2`` with the
    selected ``method``; ``name`` and ``dir`` are derived from the sources.
    """

    src1 = Instance(MarkerPointSource)
    src2 = Instance(MarkerPointSource)

    name = Property(Str, depends_on='src1.name,src2.name')
    dir = Property(Str, depends_on='src1.dir,src2.dir')

    points = Property(Array(float, (5, 3)),
                      depends_on=[
                          'method', 'src1.points', 'src1.use', 'src2.points',
                          'src2.use'
                      ])
    enabled = Property(Bool, depends_on=['points'])

    method = Enum('Transform',
                  'Average',
                  desc="Transform: estimate a rotation"
                  "/translation from mrk1 to mrk2; Average: use the average "
                  "of the mrk1 and mrk2 coordinates for each point.")

    view = View(
        VGroup(Item('method', style='custom'),
               Item('save_as', enabled_when='can_save', show_label=False)))

    @cached_property
    def _get_dir(self):
        # The directory is taken from the first source.
        return self.src1.dir

    @cached_property
    def _get_name(self):
        """Combine the two source names into one.

        If only one source has a name, that name is used. Otherwise the
        result is the longest common prefix of both names (possibly '').
        """
        n1 = self.src1.name
        n2 = self.src2.name

        if not n1:
            if n2:
                return n2
            else:
                return ''
        elif not n2:
            return n1

        # Count the leading characters shared by both names. zip() stops at
        # the end of the shorter name, so indexing can never run past either
        # string and a name that is a prefix of the other is returned whole.
        n_common = 0
        for c1, c2 in zip(n1, n2):
            if c1 != c2:
                break
            n_common += 1
        return n1[:n_common]

    @cached_property
    def _get_enabled(self):
        # Enabled as soon as any derived coordinate is non-zero.
        return np.any(self.points)

    @cached_property
    def _get_points(self):
        # in case only one or no source is enabled
        if not (self.src1 and self.src1.enabled):
            if (self.src2 and self.src2.enabled):
                return self.src2.points
            else:
                return np.zeros((5, 3))
        elif not (self.src2 and self.src2.enabled):
            return self.src1.points

        # Average method: points used by both sources are averaged; points
        # used by only one source are copied from that source.
        if self.method == 'Average':
            if len(np.union1d(self.src1.use, self.src2.use)) < 5:
                error(None, "Need at least one source for each point.",
                      "Marker Average Error")
                return np.zeros((5, 3))

            pts = (self.src1.points + self.src2.points) / 2.
            for i in np.setdiff1d(self.src1.use, self.src2.use):
                pts[i] = self.src1.points[i]
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = self.src2.points[i]

            return pts

        # Transform method: fit src1 -> src2 on the shared points, then apply
        # half of the fitted parameters (src1 forward, src2 with negated
        # parameters).
        idx = np.intersect1d(np.array(self.src1.use),
                             np.array(self.src2.use),
                             assume_unique=True)
        if len(idx) < 3:
            error(None, "Need at least three shared points for trans"
                  "formation.", "Marker Interpolation Error")
            return np.zeros((5, 3))

        src_pts = self.src1.points[idx]
        tgt_pts = self.src2.points[idx]
        est = fit_matched_points(src_pts, tgt_pts, out='params')
        rot = np.array(est[:3]) / 2.
        tra = np.array(est[3:]) / 2.

        if len(self.src1.use) == 5:
            trans = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans, self.src1.points)
        elif len(self.src2.use) == 5:
            trans = np.dot(translation(*-tra), rotation(*-rot))
            pts = apply_trans(trans, self.src2.points)
        else:
            trans1 = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans1, self.src1.points)
            trans2 = np.dot(translation(*-tra), rotation(*-rot))
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = apply_trans(trans2, self.src2.points[i])

        return pts
Example #11
0
class CombineMarkersPanel(HasTraits):  # noqa: D401
    """Panel combining two marker point sources into a third, derived one."""

    model = Instance(CombineMarkersModel, ())

    # model references for UI
    mrk1 = Instance(MarkerPointSource)
    mrk2 = Instance(MarkerPointSource)
    mrk3 = Instance(MarkerPointDest)
    distance = Str

    # Visualization
    scene = Instance(MlabSceneModel)
    scale = Float(5e-3)
    mrk1_obj = Instance(PointObject)
    mrk2_obj = Instance(PointObject)
    mrk3_obj = Instance(PointObject)
    trans = Array()

    view = View(VGroup(
        VGroup(Item('mrk1', style='custom'),
               Item('mrk1_obj', style='custom'),
               show_labels=False, label="Source Marker 1", show_border=True),
        VGroup(Item('mrk2', style='custom'),
               Item('mrk2_obj', style='custom'),
               show_labels=False, label="Source Marker 2", show_border=True),
        VGroup(Item('distance', style='readonly'),
               label='Stats', show_border=True),
        VGroup(Item('mrk3', style='custom'),
               Item('mrk3_obj', style='custom'),
               show_labels=False, label="New Marker", show_border=True),
    ))

    def _mrk1_default(self):
        return self.model.mrk1

    def _mrk2_default(self):
        return self.model.mrk2

    def _mrk3_default(self):
        return self.model.mrk3

    def __init__(self, *args, **kwargs):  # noqa: D102
        super(CombineMarkersPanel, self).__init__(*args, **kwargs)

        self.model.sync_trait('distance', self, 'distance', mutual=False)

        # One point object per marker set, in mrk1, mrk2, mrk3 order; each
        # one's visibility follows the 'enabled' flag of its model marker.
        marker_colors = (('mrk1', (0.608, 0.216, 0.216)),
                         ('mrk2', (0.216, 0.608, 0.216)),
                         ('mrk3', (0.588, 0.784, 1.)))
        for mrk_name, color in marker_colors:
            point_obj = PointObject(scene=self.scene,
                                    color=color,
                                    point_scale=self.scale)
            setattr(self, mrk_name + '_obj', point_obj)
            model_mrk = getattr(self.model, mrk_name)
            model_mrk.sync_trait('enabled', point_obj, 'visible',
                                 mutual=False)

    @on_trait_change('model:mrk1:points,trans')
    def _update_mrk1(self):
        point_obj = self.mrk1_obj
        if point_obj is None:
            return
        point_obj.points = apply_trans(self.trans, self.model.mrk1.points)

    @on_trait_change('model:mrk2:points,trans')
    def _update_mrk2(self):
        point_obj = self.mrk2_obj
        if point_obj is None:
            return
        point_obj.points = apply_trans(self.trans, self.model.mrk2.points)

    @on_trait_change('model:mrk3:points,trans')
    def _update_mrk3(self):
        point_obj = self.mrk3_obj
        if point_obj is None:
            return
        point_obj.points = apply_trans(self.trans, self.model.mrk3.points)
Example #12
0
else:
    if sys.platform in ('win32', 'linux2'):
        # on Windows and Ubuntu, multiple wildcards does not seem to work
        mrk_wildcard = ["*.sqd", "*.mrk", "*.txt", "*.pickled"]
    else:
        mrk_wildcard = ["*.sqd;*.mrk;*.txt;*.pickled"]
    mrk_out_wildcard = "*.txt"
# File extension for output marker files.
out_ext = '.txt'

# Check-list editors over the five marker point indices (0-4), laid out in a
# single column (vertical) or in five columns (horizontal).
use_editor_v = CheckListEditor(cols=1,
                               values=[(i, str(i)) for i in range(5)])
use_editor_h = CheckListEditor(cols=5,
                               values=[(i, str(i)) for i in range(5)])

# Editable marker view: file, read-only name, the point-use check list next
# to the coordinates, and Clear / Save As buttons enabled by 'can_save'.
mrk_view_editable = View(VGroup(
    'file',
    Item('name', show_label=False, style='readonly'),
    HGroup(Item('use', editor=use_editor_v, enabled_when="enabled",
                style='custom'),
           'points'),
    HGroup(Item('clear', enabled_when="can_save", show_label=False),
           Item('save_as', enabled_when="can_save", show_label=False)),
))

mrk_view_basic = View(
    VGroup(
        'file',
        Item('name', show_label=False, style='readonly'),
Example #13
0
class Volume(Module):
    """The Volume module visualizes scalar fields using volumetric
    visualization techniques.  This supports ImageData and
    UnstructuredGrid data.  It also supports the FixedPointRenderer
    for ImageData.  However, the performance is slow so your best bet
    is probably with the ImageData based renderers.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Name of the volume mapper in use; the choices come from `_mapper_types`.
    volume_mapper_type = DEnum(values_name='_mapper_types',
                               desc='volume mapper to use')

    # Name of the ray cast function in use; the choices come from
    # `_ray_cast_functions`.
    ray_cast_function_type = DEnum(values_name='_ray_cast_functions',
                                   desc='Ray cast function to use')

    # The tvtk.Volume actor, created in `setup_pipeline`.
    volume = ReadOnly

    # Read-only views onto the private `_volume_mapper`, `_volume_property`
    # and `_ray_cast_function` traits (see the `_get_*` methods).
    volume_mapper = Property(record=True)

    volume_property = Property(record=True)

    ray_cast_function = Property(record=True)

    lut_manager = Instance(VolumeLUTManager,
                           args=(),
                           allow_none=False,
                           record=True)

    input_info = PipelineInfo(datasets=['image_data', 'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['scalars'])

    ########################################
    # View related code.

    update_ctf = Button('Update CTF')

    view = View(Group(Item(name='_volume_property',
                           style='custom',
                           editor=CustomEditor(gradient_editor_factory),
                           resizable=True),
                      Item(name='update_ctf'),
                      label='CTF',
                      show_labels=False),
                Group(
                    Item(name='volume_mapper_type'),
                    Group(Item(name='_volume_mapper',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    Item(name='ray_cast_function_type'),
                    Group(Item(name='_ray_cast_function',
                               enabled_when='len(_ray_cast_functions) > 0',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    label='Mapper',
                ),
                Group(Item(name='_volume_property',
                           style='custom',
                           resizable=True),
                      label='Property',
                      show_labels=False),
                Group(Item(name='volume',
                           style='custom',
                           editor=InstanceEditor(),
                           resizable=True),
                      label='Volume',
                      show_labels=False),
                Group(Item(name='lut_manager', style='custom', resizable=True),
                      label='Legend',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits
    _volume_mapper = Instance(tvtk.AbstractVolumeMapper)
    _volume_property = Instance(tvtk.VolumeProperty)
    _ray_cast_function = Instance(tvtk.Object)

    # Mapper names offered to the user; narrowed per input dataset by
    # `_setup_mapper_types`.
    _mapper_types = List(Str, [
        'TextureMapper2D',
        'RayCastMapper',
    ])

    # Mapper names reported by `find_volume_mappers()` (plus VolumePro when
    # present), filled in by `setup_pipeline`.
    _available_mapper_types = List(Str)

    _ray_cast_functions = List(Str)

    # Scalar range of the input; changing it rescales the transfer functions.
    current_range = Tuple

    # The color transfer function.
    _ctf = Instance(ColorTransferFunction)
    # The opacity values.
    _otf = Instance(PiecewiseFunction)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the pickle state with the transfer functions serialized.

        The CTF/OTF are stored under ``'ctf_state'`` via `save_ctfs`; the
        live `_ctf`/`_otf` objects and `current_range` are dropped.
        """
        d = super(Volume, self).__get_pure_state__()
        d['ctf_state'] = save_ctfs(self._volume_property)
        for name in ('current_range', '_ctf', '_otf'):
            d.pop(name, None)
        return d

    def __set_pure_state__(self, state):
        """Restore from *state*, rebuilding the transfer functions.

        The mapper type is applied first, then the remaining traits, then the
        CTF/OTF from ``state['ctf_state']``.
        """
        self.volume_mapper_type = state['_volume_mapper_type']
        state_pickler.set_state(self, state, ignore=['ctf_state'])
        ctf_state = state['ctf_state']
        ctf, otf = load_ctfs(ctf_state, self._volume_property)
        self._ctf = ctf
        self._otf = otf
        self._update_ctf_fired()

    ######################################################################
    # `Module` interface
    ######################################################################
    def start(self):
        """Start the module and its LUT manager."""
        super(Volume, self).start()
        self.lut_manager.start()

    def stop(self):
        """Stop the module and its LUT manager."""
        super(Volume, self).stop()
        self.lut_manager.stop()

    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.
        """
        v = self.volume = tvtk.Volume()
        vp = self._volume_property = tvtk.VolumeProperty()

        # Default transfer functions over 0..255; rescaled later once the
        # real scalar range is known (see `_current_range_changed`).
        self._ctf = ctf = default_CTF(0, 255)
        self._otf = otf = default_OTF(0, 255)
        vp.set_color(ctf)
        vp.set_scalar_opacity(otf)
        vp.shade = True
        vp.interpolation_type = 'linear'
        v.property = vp

        # Re-render whenever the actor or its property changes.
        v.on_trait_change(self.render)
        vp.on_trait_change(self.render)

        available_mappers = find_volume_mappers()
        if is_volume_pro_available():
            self._mapper_types.append('VolumeProMapper')
            available_mappers.append('VolumeProMapper')

        self._available_mapper_types = available_mappers
        if 'FixedPointVolumeRayCastMapper' in available_mappers:
            self._mapper_types.append('FixedPointVolumeRayCastMapper')

        self.actors.append(v)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        dataset = mm.source.outputs[0]

        if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
            if (not dataset.is_a('vtkImageData')
                    and not dataset.is_a('vtkUnstructuredGrid')):
                error('Volume rendering only works with '
                      'StructuredPoints/ImageData/UnstructuredGrid datasets')
                return
        elif not dataset.is_a('vtkImageData'):
            error('Volume rendering only works with '
                  'StructuredPoints/ImageData datasets')
            return

        self._setup_mapper_types()
        self._setup_current_range()
        self._volume_mapper_type_changed(self.volume_mapper_type)
        self._update_ctf_fired()
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self._setup_mapper_types()
        self._setup_current_range()
        self._update_ctf_fired()
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _setup_mapper_types(self):
        """Narrow `_mapper_types` to the mappers usable with the input.

        * Unstructured grids: the available unstructured-grid mappers, or
          the ``''`` placeholder when none are available.
        * Image data with unsigned char/short scalars: texture and ray cast
          mappers plus any available fixed-point / VolumePro mapper.
        * Image data with other scalar types: only the fixed-point ray cast
          mapper (an error is reported if it is unavailable).
        """
        dataset = self.module_manager.source.outputs[0]
        available = self._available_mapper_types

        if dataset.is_a('vtkUnstructuredGrid'):
            if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
                check = [
                    'UnstructuredGridVolumeZSweepMapper',
                    'UnstructuredGridVolumeRayCastMapper',
                ]
                mapper_types = [m for m in check if m in available]
                # '' is a placeholder when no mapper is available;
                # `_volume_mapper_type_changed` ignores it.
                self._mapper_types = mapper_types or ['']
            return

        scalars = dataset.point_data.scalars
        if scalars is None:
            # No scalars to inspect; `_setup_current_range` (called right
            # after this by both callers) reports the problem.
            return

        if scalars.data_type not in [vtkConstants.VTK_UNSIGNED_CHAR,
                                     vtkConstants.VTK_UNSIGNED_SHORT]:
            if 'FixedPointVolumeRayCastMapper' in available:
                self._mapper_types = ['FixedPointVolumeRayCastMapper']
            else:
                error('Available volume mappers only work with '
                      'unsigned_char or unsigned_short datatypes')
        else:
            mapper_types = ['TextureMapper2D', 'RayCastMapper']
            check = ['FixedPointVolumeRayCastMapper', 'VolumeProMapper']
            mapper_types.extend(m for m in check if m in available)
            self._mapper_types = mapper_types

    def _setup_current_range(self):
        """Copy the LUT defaults and set `current_range` from the scalars."""
        mm = self.module_manager
        # Set the default name and range for our lut.
        lm = self.lut_manager
        slm = mm.scalar_lut_manager
        lm.set(default_data_name=slm.default_data_name,
               default_data_range=slm.default_data_range)

        # Set the current range.
        dataset = mm.source.outputs[0]
        sc = dataset.point_data.scalars
        if sc is not None:
            rng = sc.range
        else:
            error('No scalars in input data!')
            rng = (0, 255)

        # Only assign on change so `_current_range_changed` does not rescale
        # the transfer functions needlessly.
        if self.current_range != rng:
            self.current_range = rng

    def _get_volume_mapper(self):
        return self._volume_mapper

    def _get_volume_property(self):
        return self._volume_property

    def _get_ray_cast_function(self):
        return self._ray_cast_function

    def _volume_mapper_type_changed(self, value):
        """Create the mapper named by *value* and attach it to the volume.

        Unrecognised names leave the current mapper (and its render listener)
        untouched; this includes the ``''`` placeholder from
        `_setup_mapper_types`.  Previously such names raised
        UnboundLocalError on ``new_vm``.
        """
        mm = self.module_manager
        if mm is None:
            return

        ray_cast_functions = ['']
        if value == 'RayCastMapper':
            new_vm = tvtk.VolumeRayCastMapper()
            ray_cast_functions = [
                'RayCastCompositeFunction', 'RayCastMIPFunction',
                'RayCastIsosurfaceFunction'
            ]
        elif value == 'TextureMapper2D':
            new_vm = tvtk.VolumeTextureMapper2D()
        elif value == 'VolumeProMapper':
            new_vm = tvtk.VolumeProMapper()
        elif value == 'FixedPointVolumeRayCastMapper':
            new_vm = tvtk.FixedPointVolumeRayCastMapper()
        elif value == 'UnstructuredGridVolumeRayCastMapper':
            new_vm = tvtk.UnstructuredGridVolumeRayCastMapper()
        elif value == 'UnstructuredGridVolumeZSweepMapper':
            new_vm = tvtk.UnstructuredGridVolumeZSweepMapper()
        else:
            return

        # Only detach the old mapper once a replacement actually exists.
        old_vm = self._volume_mapper
        if old_vm is not None:
            old_vm.on_trait_change(self.render, remove=True)

        self._volume_mapper = new_vm
        self._ray_cast_functions = ray_cast_functions
        if value == 'RayCastMapper':
            new_vm.volume_ray_cast_function = \
                tvtk.VolumeRayCastCompositeFunction()

        new_vm.input = mm.source.outputs[0]
        self.volume.mapper = new_vm
        new_vm.on_trait_change(self.render)

    def _update_ctf_fired(self):
        """Apply the LUT manager's lut to the volume property and render."""
        set_lut(self.lut_manager.lut, self._volume_property)
        self.render()

    def _current_range_changed(self, old, new):
        """Rescale the transfer functions to the new scalar range."""
        rescale_ctfs(self._volume_property, new)
        self.render()

    def _ray_cast_function_type_changed(self, old, new):
        """Swap the mapper's ray cast function for the one named *new*.

        An empty name means "no ray cast function" and clears
        `_ray_cast_function`.
        """
        rcf = self.ray_cast_function
        # Guard against a missing function object so a stale listener removal
        # cannot raise AttributeError on None.
        if old and rcf is not None:
            rcf.on_trait_change(self.render, remove=True)

        if new:
            new_rcf = getattr(tvtk, 'Volume%s' % new)()
            new_rcf.on_trait_change(self.render)
            self._volume_mapper.volume_ray_cast_function = new_rcf
            self._ray_cast_function = new_rcf
        else:
            self._ray_cast_function = None

        self.render()

    def _scene_changed(self, old, new):
        """Propagate the new scene to the LUT manager as well."""
        super(Volume, self)._scene_changed(old, new)
        self.lut_manager.scene = new
Example #14
0
class BMCSLauncher(HasTraits):
    """Launcher window with one button per BMCS lecture demo.

    Each ``Button`` trait has a matching ``_<name>_fired`` handler that
    starts the corresponding simulation/explorer.
    """

    version = Constant(CURRENT_VERSION)
    #=========================================================================
    # Lectures #1-2: bond-slip models
    #=========================================================================
    bond_slip_model_d = Button(label='Damage')

    def _bond_slip_model_d_fired(self):
        run_bond_sim_damage()

    bond_slip_model_p = Button(label='Plasticity')

    def _bond_slip_model_p_fired(self):
        run_bond_sim_elasto_plasticity()

    bond_slip_model_dp = Button(label='Damage-plasticity')

    def _bond_slip_model_dp_fired(self):
        # Not implemented yet; the button is disabled in the view
        # (enabled_when='False').  Intended call:
        # run_bond_slip_model_dp(kind='live')
        pass

    #=========================================================================
    # Lectures #3-6: pull-out models
    #=========================================================================
    pullout_model_const_shear = Button(label='Constant shear')

    def _pullout_model_const_shear_fired(self):
        run_pullout_const_shear(kind='live')

    pullout_model_multilinear = Button(label='Multi-linear')

    def _pullout_model_multilinear_fired(self):
        run_pullout_multilinear(kind='live')

    pullout_model_frp_damage = Button(label='FRP damage')

    def _pullout_model_frp_damage_fired(self):
        run_pullout_frp_damage(kind='live')

    pullout_model_ep = Button(label='Elasto-plasticity')

    def _pullout_model_ep_fired(self):
        run_pullout_ep_cyclic()

    pullout_model_dp = Button(label='Damage-plasticity')

    def _pullout_model_dp_fired(self):
        run_pullout_dp(kind='live')

    pullout_model_fatigue = Button(label='Damage-fatigue')

    def _pullout_model_fatigue_fired(self):
        run_pullout_fatigue(kind='live')

    #=========================================================================
    # Lectures #7-9: bending, crack propagation
    #=========================================================================

    tensile_test_2d_sdamage = Button(label='Tensile test - isotropic damage')

    def _tensile_test_2d_sdamage_fired(self):
        run_tension_sdamage_viz3d(kind='live')

    # Defined once only; an identical duplicate definition was removed.
    bending3pt_2d_sdamage_viz2d = Button(
        label='bending test 3Pt - isotropic damage (2D-light)')

    def _bending3pt_2d_sdamage_viz2d_fired(self):
        run_bending3pt_sdamage_viz2d(kind='live')

    bending3pt_3d = Button(label='Bending test (3D)')

    def _bending3pt_3d_fired(self):
        run_bending3pt_mic_odf(kind='live')

    bending3pt_2d_sdamage_viz3d = Button(
        label='Bending test 3Pt - isotropic damage (2D-heavy)')

    def _bending3pt_2d_sdamage_viz3d_fired(self):
        run_bending3pt_sdamage_viz3d(kind='live')

    #=========================================================================
    # Lecture #10: yield surface explorer
    #=========================================================================
    yc_explorer = Button(label='Yield conditions for concrete')

    def _yc_explorer_fired(self):
        run_explorer(kind='live')

    view = View(VGroup(
        HGroup(
            Spring(),
            Item(
                'version',
                style='readonly',
                full_size=True,
                resizable=True,
            )),
        Group(UItem('bond_slip_model_d',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('bond_slip_model_p',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('bond_slip_model_dp',
                    full_size=True,
                    resizable=True,
                    enabled_when='False'),
              label='Bond-slip models, lecture #1-2'),
        Group(UItem('pullout_model_const_shear',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('pullout_model_multilinear',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('pullout_model_ep',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('pullout_model_frp_damage',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('pullout_model_dp',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('pullout_model_fatigue',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              label='Pull-out models, lecture #3-6'),
        Group(UItem('tensile_test_2d_sdamage',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('bending3pt_2d_sdamage_viz2d',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('bending3pt_2d_sdamage_viz3d',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              UItem('bending3pt_3d',
                    full_size=True,
                    resizable=True,
                    enabled_when='True'),
              label='Bending, crack propagation, lecture #7-9'),
        Group(UItem('yc_explorer', full_size=True, resizable=True),
              label='Yield surface explorer #10'),
    ),
                title='BMCS application launcher',
                width=500,
                buttons=['OK'])
class ExtractGrid(FilterBase):
    """Select a sub-volume of, and/or subsample, a structured input.

    Works on StructuredPoints/ImageData, StructuredGrid and RectilinearGrid
    inputs; the concrete tvtk extractor is picked in `update_pipeline`.
    """
    # The version of this class.  Used for persistence.
    __version__ = 0

    # Per-axis bounds of the volume of interest.  Slider limits come from
    # the private `_<axis>_low` / `_<axis>_high` traits below.
    x_min = Range(value=0, low='_x_low', high='_x_high',
                  enter_set=True, auto_set=False,
                  desc='minimum x value of the domain')

    x_max = Range(value=10000, low='_x_low', high='_x_high',
                  enter_set=True, auto_set=False,
                  desc='maximum x value of the domain')

    y_min = Range(value=0, low='_y_low', high='_y_high',
                  enter_set=True, auto_set=False,
                  desc='minimum y value of the domain')

    y_max = Range(value=10000, low='_y_low', high='_y_high',
                  enter_set=True, auto_set=False,
                  desc='maximum y value of the domain')

    z_min = Range(value=0, low='_z_low', high='_z_high',
                  enter_set=True, auto_set=False,
                  desc='minimum z value of the domain')

    z_max = Range(value=10000, low='_z_low', high='_z_high',
                  enter_set=True, auto_set=False,
                  desc='maximum z value of the domain')

    # Per-axis sample rates (keep every n-th point).
    x_ratio = Range(value=1, low='_min_sample', high='_x_s_high',
                    enter_set=True, auto_set=False,
                    desc='sample rate along x')

    y_ratio = Range(value=1, low='_min_sample', high='_y_s_high',
                    enter_set=True, auto_set=False,
                    desc='sample rate along y')

    z_ratio = Range(value=1, low='_min_sample', high='_z_s_high',
                    enter_set=True, auto_set=False,
                    desc='sample rate along z')

    # The actual TVTK filter that this class manages.
    filter = Instance(tvtk.Object, tvtk.ExtractVOI(), allow_none=False)

    input_info = PipelineInfo(
        datasets=['image_data', 'rectilinear_grid', 'structured_grid'],
        attribute_types=['any'],
        attributes=['any'])

    output_info = PipelineInfo(
        datasets=['image_data', 'rectilinear_grid', 'structured_grid'],
        attribute_types=['any'],
        attributes=['any'])

    ########################################
    # Private traits: lower/upper limits used by the sliders above.
    _min_sample = Int(1)
    _x_low = Int(0)
    _x_high = Int(10000)
    _x_s_high = Int(100)
    _y_low = Int(0)
    _y_high = Int(10000)
    _y_s_high = Int(100)
    _z_low = Int(0)
    _z_high = Int(10000)
    _z_s_high = Int(100)

    ########################################
    # The View for this object.
    view = View(
        Group(
            Item(label='Select Volume Of Interest'),
            Item(name='x_min'),
            Item(name='x_max'),
            Item(name='y_min'),
            Item(name='y_max'),
            Item(name='z_min'),
            Item(name='z_max'),
            Item('_'),
            Item(label='Select Sample Ratio'),
            Item(name='x_ratio'),
            Item(name='y_ratio'),
            Item(name='z_ratio'),
            label='VOI',
        ),
        Group(
            Item(name='filter', style='custom', resizable=True),
            show_labels=False,
            label='Filter',
        ),
        resizable=True,
    )

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the pickle state minus every per-axis VOI/ratio/limit trait."""
        state = super(ExtractGrid, self).__get_pure_state__()
        patterns = ('{0}_min', '{0}_max', '_{0}_low', '_{0}_high',
                    '_{0}_s_high', '{0}_ratio')
        for axis in 'xyz':
            for pattern in patterns:
                state.pop(pattern.format(axis), None)
        return state

    ######################################################################
    # `Filter` interface
    ######################################################################
    def update_pipeline(self):
        """Pick the extractor matching the input type and wire it up.

        Reports an error and leaves the outputs unchanged for unsupported
        dataset types.
        """
        inputs = self.inputs
        if not inputs:
            return

        source = inputs[0]
        dataset = source.get_output_dataset()
        extractors = (
            ('vtkStructuredGrid', tvtk.ExtractGrid),
            ('vtkRectilinearGrid', tvtk.ExtractRectilinearGrid),
            ('vtkImageData', tvtk.ExtractVOI),
        )
        klass = next(
            (k for vtk_name, k in extractors if dataset.is_a(vtk_name)), None)
        if klass is None:
            error('This filter does not support %s objects' %
                  dataset.__class__.__name__)
            return
        self.filter = klass()

        fil = self.filter
        self.configure_connection(fil, source)
        self._update_limits()
        self._update_voi()
        self._update_sample_rate()
        fil.update()
        self._set_outputs([fil])

    def update_data(self):
        """This method is invoked (automatically) when any of the
        inputs sends a `data_changed` event.
        """
        self._update_limits()
        # Re-run the filter and propagate the data_changed event.
        self._flush_filter(self.filter)

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _update_limits(self):
        """Refresh the slider limits from the input's extent."""
        fil = self.filter
        if is_old_pipeline():
            extents = fil.input.whole_extent
        elif VTK_MAJOR_VERSION <= 7:
            extents = fil.get_update_extent()
        else:
            extents = fil.input.extent

        # An inverted (min > max) pair on any axis means the extent is not
        # usable; fall back to the input dataset's dimensions for the maxima.
        if any(extents[i] > extents[i + 1] for i in (0, 2, 4)):
            dims = self.inputs[0].get_output_dataset().dimensions
            extents = [extents[0], dims[0] - 1,
                       extents[2], dims[1] - 1,
                       extents[4], dims[2] - 1]

        for index, axis in enumerate('xyz'):
            setattr(self, '_%s_low' % axis, extents[2 * index])
            setattr(self, '_%s_high' % axis, extents[2 * index + 1])
        # Sample-rate sliders always get an upper bound of at least 1.
        for axis in 'xyz':
            upper = getattr(self, '_%s_high' % axis)
            setattr(self, '_%s_s_high' % axis, max(1, upper))

    def _sync_max_with_min(self, axis, val):
        """Raise ``<axis>_max`` up to *val* if needed, else refresh the VOI."""
        max_name = axis + '_max'
        if val > getattr(self, max_name):
            setattr(self, max_name, val)
        else:
            self._update_voi()

    def _sync_min_with_max(self, axis, val):
        """Lower ``<axis>_min`` down to *val* if needed, else refresh the VOI."""
        min_name = axis + '_min'
        if val < getattr(self, min_name):
            setattr(self, min_name, val)
        else:
            self._update_voi()

    def _x_min_changed(self, val):
        self._sync_max_with_min('x', val)

    def _x_max_changed(self, val):
        self._sync_min_with_max('x', val)

    def _y_min_changed(self, val):
        self._sync_max_with_min('y', val)

    def _y_max_changed(self, val):
        self._sync_min_with_max('y', val)

    def _z_min_changed(self, val):
        self._sync_max_with_min('z', val)

    def _z_max_changed(self, val):
        self._sync_min_with_max('z', val)

    def _x_ratio_changed(self):
        self._update_sample_rate()

    def _y_ratio_changed(self):
        self._update_sample_rate()

    def _z_ratio_changed(self):
        self._update_sample_rate()

    def _update_voi(self):
        """Push (x_min, x_max, y_min, y_max, z_min, z_max) to the filter."""
        fil = self.filter
        fil.voi = tuple(getattr(self, axis + bound)
                        for axis in 'xyz' for bound in ('_min', '_max'))
        self._flush_filter(fil)

    def _update_sample_rate(self):
        """Push (x_ratio, y_ratio, z_ratio) to the filter."""
        fil = self.filter
        fil.sample_rate = tuple(getattr(self, axis + '_ratio')
                                for axis in 'xyz')
        self._flush_filter(fil)

    def _flush_filter(self, fil):
        """Re-execute *fil* over its whole extent and flag new data."""
        fil.update_whole_extent()
        fil.update()
        self.data_changed = True

    def _filter_changed(self, old, new):
        """Move the render listener from the old filter to the new one."""
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)
Example #16
0
class InfoPanel(HasTraits):
    """Control panel with cursor/mouse readouts and electrode buttons.

    The ``cursor``, ``cursor_ras`` and ``cursor_tkr`` tuple properties are
    thin views onto the editable ``*_csvlist`` float lists shown in the UI.
    """
    cursor = Property(Tuple)
    cursor_ras = Property(Tuple)
    cursor_tkr = Property(Tuple)
    cursor_intensity = Float

    mouse = Tuple((0.0, 0.0, 0.0))
    mouse_ras = Tuple((0.0, 0.0, 0.0))
    mouse_tkr = Tuple((0.0, 0.0, 0.0))
    mouse_intensity = Float

    # Backing storage for the cursor properties (edited as CSV text).
    cursor_csvlist = List(Float)
    cursor_ras_csvlist = List(Float)
    cursor_tkr_csvlist = List(Float)

    pin_tolerance = Float(7.5)
    currently_showing_list = List(Instance(NullInstanceHolder))
    currently_showing = Instance(NullInstanceHolder)

    add_electrode_button = Button('Make new elec here')
    confirm_movepin_internal_button = Button('Move elec here')
    confirm_movepin_postproc_button = Button('Move postproc')
    track_cursor_button = Button('Track cursor')
    reset_image_button = Button('Center image')

    minimum_contrast = Float(-2000)
    maximum_contrast = Float(5000)

    traits_view = View(
        VGroup(
            Item('currently_showing',
                 editor=InstanceEditor(name='currently_showing_list'),
                 style='custom'),
            Spring(),
            HGroup(
                Item('minimum_contrast',
                     editor=TextEditor(enter_set=True, auto_set=False,
                                       evaluate=float)),
                Item('maximum_contrast',
                     editor=TextEditor(enter_set=True, auto_set=False,
                                       evaluate=float)),
            ),
            Spring(),
            HGroup(
                Item('add_electrode_button', show_label=False),
                Item('track_cursor_button', show_label=False),
                Item('reset_image_button', show_label=False),
            ),
            HGroup(
                Item('confirm_movepin_internal_button', show_label=False),
                Item('confirm_movepin_postproc_button', show_label=False),
            ),
            Item('pin_tolerance'),
            Spring(),
            Item(name='cursor_csvlist', style='text', label='cursor',
                 editor=CSVListEditor(enter_set=True, auto_set=False)),
            Item(name='cursor_ras_csvlist', style='text', label='cursor RAS',
                 editor=CSVListEditor(enter_set=True, auto_set=False)),
            Item(name='cursor_tkr_csvlist', style='text', label='cursor tkr',
                 editor=CSVListEditor(enter_set=True, auto_set=False)),
            Item(name='cursor_intensity', style='readonly',
                 label='cursor intensity'),
            Item(name='mouse', style='readonly', label='mouse'),
            Item(name='mouse_ras', style='readonly', label='mouse RAS'),
            Item(name='mouse_tkr', style='readonly', label='mouse tkr'),
            Item(name='mouse_intensity', style='readonly',
                 label='mouse intensity'),
        ),
        height=400,
        width=400,
        title='ilumbumbargu',
    )

    # Property accessors: each tuple property reads/writes its CSV list.
    def _get_cursor(self):
        return tuple(self.cursor_csvlist)

    def _set_cursor(self, value):
        self.cursor_csvlist = list(value)

    def _get_cursor_ras(self):
        return tuple(self.cursor_ras_csvlist)

    def _set_cursor_ras(self, value):
        self.cursor_ras_csvlist = list(value)

    def _get_cursor_tkr(self):
        return tuple(self.cursor_tkr_csvlist)

    def _set_cursor_tkr(self, value):
        self.cursor_tkr_csvlist = list(value)
Example #17
0
 def trait_view(self, name=None, view_elements=None):
     """Build the application window's View.

     ``name`` picks the layout: ``None`` or ``'full'`` gives the complete
     workspace (search/help, canvas, code, context table, status line);
     ``'simple'`` gives only search/help and the canvas.  Any other name
     returns ``None``.  ``view_elements`` is accepted but not used here.
     """

     def search_help_split():
         # Function search box stacked above the HTML help window.
         return VSplit(
             Item('function_search',
                  editor=InstanceEditor(view=function_search_view),
                  label='Search',
                  id='search',
                  style='custom',
                  dock='horizontal',
                  show_label=False),
             Item('html_window',
                  style='custom',
                  show_label=False,
                  springy=True,
                  resizable=True),
             id='search_help_view',
         )

     def canvas_item():
         # FIXME:  need a new way to control the canvas
         # not using BlockEditor
         return Item('object.project.active_experiment.canvas',
                     label='Canvas',
                     id='canvas',
                     editor=BlockEditor(),
                     dock='horizontal',
                     show_label=False)

     def f5_bindings():
         # Bind F5 to the `_on_execute` method.
         return KeyBindings(
             KeyBinding(binding1='F5', method_name='_on_execute'),
         )

     if name in (None, 'full'):
         panels = HSplit(
             search_help_split(),
             VSplit(
                 canvas_item(),
                 Item('object.project.active_experiment.exec_model.code',
                      label='Code',
                      id='code',
                      editor=CodeEditor(dim_lines='dim_lines',
                                        dim_color='dim_color',
                                        squiggle_lines='squiggle_lines'),
                      dock='horizontal',
                      show_label=False),
             ),
             Item('context_viewer',
                  label='Context',
                  id='context_table',
                  editor=InstanceEditor(),
                  style='custom',
                  dock='horizontal',
                  show_label=False),
             id='panel_split',
         )
         status_line = Item('status',
                            style='readonly',
                            show_label=False,
                            resizable=False)
         return View(
             VGroup(panels, status_line),
             title='Block Canvas',
             menubar=BlockApplicationMenuBar,
             width=1024,
             height=768,
             id='blockcanvas.app.application',
             resizable=True,
             handler=BlockApplicationViewHandler(model=self),
             key_bindings=f5_bindings(),
         )

     if name == 'simple':
         return View(
             HSplit(search_help_split(), canvas_item(), id='panel_split'),
             title='Block Canvas - Simple View',
             menubar=BlockApplicationMenuBar,
             width=800,
             height=600,
             id='blockcanvas.app.application.simple',
             resizable=True,
             handler=BlockApplicationViewHandler(model=self),
             key_bindings=f5_bindings(),
         )

     return None
Example #18
0
class TwoDimensionalPanel(Handler):
    '''Three orthogonal 2D slice views (xy, xz, yz) of a loaded volume.

    The panel keeps a shared cursor (in pixel/voxel coordinates), mirrors the
    cursor and mouse position into ``info_panel`` in pixel, RAS and tkr
    coordinates, and draws per-image "pins" as coloured dots on every plane
    whose slice lies within ``pin_tolerance`` of the pin.
    '''
    images = Dict # Str -> Tuple(imgd, affine, tkr_affine)

    # copy of the image on display; cut_data clips it in place to the
    # contrast window, so the arrays stored in self.images are never altered
    current_image = Any # np.ndarray XxYxZ
    current_affine = Any
    current_tkr_affine = Any

    xy_plane = Instance(Plot)
    xz_plane = Instance(Plot)
    yz_plane = Instance(Plot)

    pins = Dict # Str -> (Str -> 4-Tuple (x, y, z, color))
    pin_tolerance = DelegatesTo('info_panel')
    minimum_contrast = DelegatesTo('info_panel')
    maximum_contrast = DelegatesTo('info_panel')

    current_pin = Str(None)
    confirm_movepin_postproc_button = DelegatesTo('info_panel')
    confirm_movepin_internal_button = DelegatesTo('info_panel')
    add_electrode_button = DelegatesTo('info_panel')
    move_electrode_internally_event = Event
    move_electrode_postprocessing_event = Event
    add_electrode_event = Event
    track_cursor_button = DelegatesTo('info_panel')
    track_cursor_event = Event
    untrack_cursor_event = Event
    panel2d_closed_event = Event
    reset_image_button = DelegatesTo('info_panel')

    info_panel = Instance(InfoPanel, ())

    currently_showing_list = DelegatesTo('info_panel')
    currently_showing = DelegatesTo('info_panel')

    # TODO: rename cursor to "coord"
    cursor = Tuple # 3-tuple, pixel coordinates

    null = Any # None

    # False while plotting is in progress; the info-panel listeners below
    # only react when this is True (see the note above _listen_cursor)
    _finished_plotting = Bool(False)

    traits_view = View(
        Group(
            HGroup(
                Item(name='xz_plane', editor=ComponentEditor(),
                     height=400, width=400, show_label=False, resizable=True),
                Item(name='yz_plane', editor=ComponentEditor(),
                     height=400, width=400, show_label=False, resizable=True),
            ),
            HGroup(
                Item(name='xy_plane', editor=ComponentEditor(),
                     height=400, width=400, show_label=False, resizable=True),
                Item(name='info_panel', editor=InstanceEditor(),
                     style='custom',
                     height=400, width=400, show_label=False, resizable=True),
            ),
        ),
        title='Contact 867-5309 for blobfish sales',
    )

    def map_cursor(self, cursor, affine, invert=False):
        '''Map the 3D point ``cursor`` through ``affine`` (through its inverse
        when ``invert`` is True) and return the result as a tuple with every
        component passed through truncate(..., 2).'''
        aff_to_use = np.linalg.inv(affine) if invert else affine
        mcursor, = apply_affine([cursor], aff_to_use)
        return tuple(truncate(c, 2) for c in mcursor)

    def _voxel_index(self, coord, image=None):
        '''Return the integer voxel index nearest to pixel coordinate
        ``coord``, clamped to the bounds of ``image`` (default: the image on
        display).

        numpy rejects float indices, and rounding a coordinate that lies just
        below an axis length would otherwise step one past the last voxel.
        '''
        if image is None:
            image = self.current_image
        return tuple(int(min(max(np.round(c), 0), size - 1))
                     for c, size in zip(coord, image.shape))

    def cut_data(self, ndata, mcursor):
        '''Clip ``ndata`` IN PLACE to [minimum_contrast, maximum_contrast] and
        return the three orthogonal slices through pixel coordinate
        ``mcursor`` as (xy_cut, xz_cut, yz_cut), each transposed for display.
        '''
        xm, ym, zm = self._voxel_index(mcursor, ndata)

        ndata[ndata < self.minimum_contrast] = self.minimum_contrast
        ndata[ndata > self.maximum_contrast] = self.maximum_contrast

        yz_cut = ndata[xm, :, :].T
        xz_cut = ndata[:, ym, :].T
        xy_cut = ndata[:, :, zm].T
        return xy_cut, xz_cut, yz_cut

    def load_img(self, imgf, reorient2std=False, image_name=None):
        '''Load the image file ``imgf``, register it under ``image_name`` (a
        generated ``image<N>`` name when None), select it in the info panel
        and display it.

        When ``reorient2std`` is True, the y and z data axes are swapped and
        the new z axis flipped; both affines are pre-multiplied with the
        corresponding reorientation matrix.
        '''
        self._finished_plotting = False

        img = nib.load(imgf)

        aff = np.dot(
            get_orig2std(imgf) if reorient2std else np.eye(4),
            img.get_affine())
        tkr_aff = np.dot(
            reorient_orig2std_tkr_mat if reorient2std else np.eye(4),
            get_vox2rasxfm(imgf, stem='vox2ras-tkr'))

        imgd = img.get_data()
        if reorient2std:
            imgd = np.swapaxes(imgd, 1, 2)[:, :, ::-1]

        if image_name is None:
            from utils import gensym
            image_name = 'image%s' % gensym()

        self.images[image_name] = (imgd, aff, tkr_aff)
        self.currently_showing_list.append(
            NullInstanceHolder(name=image_name))
        self.currently_showing = self.currently_showing_list[-1]

        self.pins[image_name] = {}

        self.show_image(image_name)

    @on_trait_change('currently_showing, minimum_contrast, maximum_contrast')
    def switch_image(self):
        '''Redisplay the selected image when the selection or the contrast
        window changes; does nothing while no image is selected.'''
        if self.currently_showing is None:
            return
        self.show_image(self.currently_showing.name)

    @on_trait_change('reset_image_button')
    def center_image(self):
        '''Reset every plane's axis ranges so the whole image is visible;
        does nothing before an image has been loaded.'''
        if self.current_image is None:
            return

        x, y, z = self.current_image.shape

        for plane, (r, c) in zip((self.xy_plane, self.xz_plane, self.yz_plane),
                                 ((x, y), (x, z), (y, z))):
            plane.index_mapper.range.low = 0
            plane.value_mapper.range.low = 0
            plane.index_mapper.range.high = r
            plane.value_mapper.range.high = c

    def show_image(self, image_name, xyz=None):
        '''Display ``image_name`` with the cursor at pixel coordinate ``xyz``
        (default: the image centre).

        All three planes are rebuilt from scratch, the info panel is updated
        and the image's pins are redrawn.
        '''
        cur_img_t, self.current_affine, self.current_tkr_affine = (
            self.images[image_name])

        # keep the info-panel listeners from re-entering move_cursor while
        # the panel fields are updated below
        self._finished_plotting = False

        # work on a copy: cut_data clips it in place to the contrast window
        self.current_image = cur_img_t.copy()

        if xyz is None:
            xyz = tuple(np.array(self.current_image.shape) // 2)
        self.cursor = x, y, z = xyz
        cuts = dict(zip(('xy', 'xz', 'yz'),
                        self.cut_data(self.current_image, self.cursor)))
        coords = {'x': x, 'y': y, 'z': z}

        # create and attach all three plots first, then populate them
        for orient in ('xy', 'xz', 'yz'):
            plotdata = ArrayPlotData()
            plotdata.set_data('imagedata', cuts[orient])
            for axis in orient:
                plotdata.set_data('cursor_%s' % axis,
                                  np.array((coords[axis],)))
            setattr(self, '%s_plane' % orient,
                    Plot(plotdata, bgcolor='black'))

        for orient in ('xy', 'xz', 'yz'):
            plane = getattr(self, '%s_plane' % orient)
            plane.img_plot('imagedata', name='brain', colormap=bone_cmap)
            plane.plot(tuple('cursor_%s' % axis for axis in orient),
                       type='scatter', color='red', marker='plus', size=3,
                       name='cursor')
            plane.tools.append(Click2DPanelTool(self, orient))
            plane.tools.append(ZoomTool(plane))

        self.info_panel.cursor = self.cursor
        self.info_panel.cursor_ras = self.map_cursor(self.cursor,
            self.current_affine)
        self.info_panel.cursor_tkr = self.map_cursor(self.cursor,
            self.current_tkr_affine)
        self.info_panel.cursor_intensity = truncate(
            self.current_image[self._voxel_index(self.cursor)], 3)

        self._finished_plotting = True

        self._redraw_pins(image_name)

    def _redraw_pins(self, image_name):
        '''Re-drop every pin stored for ``image_name`` so that only pins near
        the current slices are visible.'''
        for pin, (px, py, pz, pcolor) in list(
                self.pins.get(image_name, {}).items()):
            self.drop_pin(px, py, pz, name=pin, color=pcolor)

    def cursor_outside_image_dimensions(self, cursor, image=None):
        '''Return True if pixel coordinate ``cursor`` lies outside ``image``
        (default: the image on display), i.e. any component c fails
        0 <= c < axis length.'''
        if image is None:
            image = self.current_image

        x, y, z = cursor
        x_sz, y_sz, z_sz = image.shape

        return not (0 <= x < x_sz and 0 <= y < y_sz and 0 <= z < z_sz)

    def move_cursor(self, x, y, z, suppress_cursor=False, suppress_ras=False,
            suppress_tkr=False):
        '''Move the cursor to pixel coordinate (x, y, z).

        Re-slices all three planes and updates the info panel. The
        ``suppress_*`` flags skip updating the matching info-panel field,
        which is used when that field is where the move came from. Prints a
        message and does nothing if the point lies outside the image. Fires
        ``untrack_cursor_event`` at the end.
        '''
        cursor = x, y, z

        if self.cursor_outside_image_dimensions(cursor):
            print('Cursor %.2f %.2f %.2f outside image dimensions, doing '
                  'nothing' % (x, y, z))
            return

        self.cursor = cursor

        cuts = self.cut_data(self.current_image, self.cursor)

        print('clicked on point %.2f %.2f %.2f' % (x, y, z))

        coords = {'x': x, 'y': y, 'z': z}
        for orient, cut in zip(('xy', 'xz', 'yz'), cuts):
            data = getattr(self, '%s_plane' % orient).data
            data.set_data('imagedata', cut)
            for axis in orient:
                data.set_data('cursor_%s' % axis, np.array((coords[axis],)))

        if not suppress_cursor:
            self.info_panel.cursor = tuple(truncate(c, 2) for c in self.cursor)
        if not suppress_ras:
            self.info_panel.cursor_ras = self.map_cursor(self.cursor,
                self.current_affine)
        if not suppress_tkr:
            self.info_panel.cursor_tkr = self.map_cursor(self.cursor,
                self.current_tkr_affine)
        self.info_panel.cursor_intensity = truncate(
            self.current_image[self._voxel_index(cursor)], 3)

        self._redraw_pins(self.currently_showing.name)

        self.untrack_cursor_event = True

    def redraw(self):
        '''Request a redraw of all three planes.'''
        self.xz_plane.request_redraw()
        self.yz_plane.request_redraw()
        self.xy_plane.request_redraw()

    def drop_pin(self, x, y, z, name='pin', color='yellow',
            image_name=None, ras_coords=False, alter_current_pin=True):
        '''Record a pin ``name`` at (x, y, z) for ``image_name`` (default: the
        image on display).

        XYZ is given in pixel space, unless ``ras_coords`` is True. In that
        case it is first mapped to pixel space through the inverse of the
        third affine stored for ``image_name``. If ``image_name`` is on
        display, the pin is drawn on each plane whose slice lies within
        ``pin_tolerance`` of the pin and hidden on the others. When
        ``alter_current_pin`` is True, the pin becomes ``current_pin``.
        '''
        # resolve the image first: the ras_coords branch needs its affine
        if image_name is None:
            image_name = self.currently_showing.name

        if ras_coords:
            #affine might not necessarily be from image currently on display
            # NOTE(review): this uses the tkr affine (third tuple element)
            # despite the flag's name -- confirm this is intended
            _, _, affine = self.images[image_name]
            x, y, z = self.map_cursor((x, y, z), affine, invert=True)

        if image_name == self.currently_showing.name:
            cx, cy, cz = self.cursor
            tolerance = self.pin_tolerance

            # a pin is shown on a plane only when it is near that slice
            near = {'xy': np.abs(z - cz) < tolerance,
                    'xz': np.abs(y - cy) < tolerance,
                    'yz': np.abs(x - cx) < tolerance}
            coords = {'x': x, 'y': y, 'z': z}
            for orient in ('xy', 'xz', 'yz'):
                data = getattr(self, '%s_plane' % orient).data
                for axis in orient:
                    data.set_data('%s_%s' % (name, axis),
                        np.array((coords[axis],) if near[orient] else ()))

            if name not in self.xy_plane.plots:
                for orient in ('xy', 'xz', 'yz'):
                    getattr(self, '%s_plane' % orient).plot(
                        tuple('%s_%s' % (name, axis) for axis in orient),
                        type='scatter', color=color, marker='dot', size=3,
                        name=name)

            self.redraw()

        self.pins[image_name][name] = (x, y, z, color)

        if alter_current_pin:
            self.current_pin = name

    def move_mouse(self, x, y, z):
        '''Show the pixel, RAS and tkr coordinates and the intensity under
        the mouse in the info panel; ignored outside the image.'''
        mouse = (x, y, z)

        if self.cursor_outside_image_dimensions(mouse):
            return

        self.info_panel.mouse = tuple(truncate(c, 2) for c in mouse)
        self.info_panel.mouse_ras = self.map_cursor(mouse,
            self.current_affine)
        self.info_panel.mouse_tkr = self.map_cursor(mouse,
            self.current_tkr_affine)
        self.info_panel.mouse_intensity = truncate(
            self.current_image[self._voxel_index(mouse)], 3)

    def _confirm_movepin_internal_button_fired(self):
        self.move_electrode_internally_event = True

    def _confirm_movepin_postproc_button_fired(self):
        self.move_electrode_postprocessing_event = True

    def _add_electrode_button_fired(self):
        self.add_electrode_event = True

    def _track_cursor_button_fired(self):
        self.track_cursor_event = True

    def closed(self, info, is_ok):
        self.panel2d_closed_event = True

    #because these calls all call map_cursor, which changes the listener
    #variables they end up infinite looping.

    #to solve this we manage _finished_plotting manually
    #so that move_cursor is only called once when any listener is triggered
    def _move_from_info_panel(self, coords, affine=None, **suppress):
        '''Move the cursor to ``coords`` typed into the info panel.

        ``coords`` is mapped through the inverse of ``affine`` when one is
        given. The call is a no-op unless plotting is finished and ``coords``
        has 3 components. The re-entrancy guard is always restored, even if
        the move fails.
        '''
        if not (self._finished_plotting and len(coords) == 3):
            return
        self._finished_plotting = False
        try:
            if affine is not None:
                coords = self.map_cursor(coords, affine, invert=True)
            x, y, z = coords
            self.move_cursor(x, y, z, **suppress)
        finally:
            self._finished_plotting = True

    @on_trait_change('info_panel:cursor_csvlist')
    def _listen_cursor(self):
        self._move_from_info_panel(self.info_panel.cursor,
            suppress_cursor=True)

    @on_trait_change('info_panel:cursor_ras_csvlist')
    def _listen_cursor_ras(self):
        self._move_from_info_panel(self.info_panel.cursor_ras,
            self.current_affine, suppress_ras=True)

    @on_trait_change('info_panel:cursor_tkr_csvlist')
    def _listen_cursor_tkr(self):
        self._move_from_info_panel(self.info_panel.cursor_tkr,
            self.current_tkr_affine, suppress_tkr=True)
Example #19
0
# Menu/toolbar actions that invoke the `do_save_cfg` / `do_load_cfg` methods.
save_cfg, load_cfg = (
    Action(name="Save Config.", action="do_save_cfg"),
    Action(name="Load Config.", action="do_load_cfg"),
)

# Main window layout definition.
traits_view = View(
    Group(
        VGroup(
            HGroup(
                VGroup(
                    HGroup(  # Simulation Control
                        VGroup(
                            Item(
                                name="bit_rate",
                                label="Bit Rate (Gbps)",
                                tooltip="bit rate",
                                show_label=True,
                                enabled_when="True",
                                editor=TextEditor(auto_set=False,
                                                  enter_set=True,
                                                  evaluate=float),
                            ),
                            Item(
                                name="nbits",
                                label="Nbits",
                                tooltip="# of bits to run",
                                editor=TextEditor(auto_set=False,
                                                  enter_set=True,
                                                  evaluate=int),
                            ),
                            Item(
                                name="nspb",
                                label="Nspb",
Example #20
0
def queue_factory_item(name, **kw):
    """Return an ``Item`` for the queue-factory trait derived from *name*.

    The trait name comes from ``queue_factory_name(name)``; all extra keyword
    arguments are forwarded unchanged to ``Item``.
    """
    trait_name = queue_factory_name(name)
    return Item(trait_name, **kw)
Example #21
0
 def traits_view(self):
     """Return a view with one editor each for ``host`` and ``port``."""
     host_item = Item('host')
     port_item = Item('port')
     return View(host_item, port_item)
Example #22
0
 def traits_view(self):
     """Return a view showing the read-only analysis type above a table of
     isotopes rendered with ``AnalysisHealthAdapter``."""
     isotope_table = TabularEditor(adapter=AnalysisHealthAdapter())
     return View(UItem('analysis_type', style='readonly'),
                 Item('isotopes', editor=isotope_table))
Example #23
0
    @on_trait_change('set')
    def _on_set_change(self, obj, name, old, new):
        """When the selected fiducial changes, load its position (scaled by
        1000) into ``current_pos_mm`` and turn the head view towards it.

        Selections other than 'Nasion', 'LPA' and 'RPA' are ignored.
        """
        # selection -> (fiducial trait to read, headview flag to set)
        targets = {'Nasion': ('nasion', 'front'),
                   'LPA': ('lpa', 'left'),
                   'RPA': ('rpa', 'right')}
        target = targets.get(new)
        if target is None:
            return
        point_attr, view_attr = target
        self.current_pos_mm = getattr(self, point_attr) * 1000
        setattr(self.headview, view_attr, True)


# FiducialsPanel view that allows manipulating all coordinates numerically:
# the fiducials file, the active point selector, editable LPA / nasion / RPA
# values and the save / reset buttons. The whole group is disabled while
# `locked` is True.
view2 = View(
    VGroup(
        Item('fid_file', label='Fiducials File'),
        Item('fid_fname', show_label=False, style='readonly'),
        Item('set', style='custom'),
        'lpa',
        'nasion',
        'rpa',
        HGroup(
            Item('save', enabled_when='can_save'),
            Item('save_as', enabled_when='can_save_as'),
            Item('reset_fid', enabled_when='can_reset'),
            show_labels=False,
        ),
        enabled_when="locked==False",
    ),
)


class FiducialsFrame(HasTraits):
    """GUI for interpolating between two KIT marker files.

    Parameters
    ----------
    subject : None | str
Example #24
0
def run_factory_item(name, **kw):
    """Return an ``Item`` for the run-factory trait derived from *name*.

    The trait name comes from ``run_factory_name(name)``; all extra keyword
    arguments are forwarded unchanged to ``Item``.
    """
    trait_name = run_factory_name(name)
    return Item(trait_name, **kw)
Example #25
0
class FiducialsFrame(HasTraits):
    """GUI for setting the LPA, nasion and RPA fiducials on an MRI head.

    The frame shows the subject's head surface in a Mayavi scene next to a
    subject selector and a :class:`FiducialsPanel`; mouse picks on the scene
    are forwarded to the panel.

    Parameters
    ----------
    subject : None | str
        Set the subject which is initially selected.
    subjects_dir : None | str
        Override the SUBJECTS_DIR environment variable.
    """

    model = Instance(MRIHeadWithFiducialsModel, ())

    scene = Instance(MlabSceneModel, ())
    headview = Instance(HeadViewController)

    spanel = Instance(SubjectSelectorPanel)
    panel = Instance(FiducialsPanel)

    mri_obj = Instance(SurfaceObject)
    point_scale = float(defaults['mri_fid_scale'])
    lpa_obj = Instance(PointObject)
    nasion_obj = Instance(PointObject)
    rpa_obj = Instance(PointObject)

    def _headview_default(self):
        return HeadViewController(scene=self.scene, system='RAS')

    def _panel_default(self):
        # register the numeric-coordinates layout (view2) under the name
        # 'view' on this panel instance
        panel = FiducialsPanel(model=self.model, headview=self.headview)
        panel.trait_view('view', view2)
        return panel

    def _spanel_default(self):
        return SubjectSelectorPanel(model=self.model.subject_source)

    view = View(HGroup(Item('scene',
                            editor=SceneEditor(scene_class=MayaviScene),
                            dock='vertical'),
                       VGroup(headview_borders,
                              VGroup(Item('spanel', style='custom'),
                                     label="Subject", show_border=True,
                                     show_labels=False),
                              VGroup(Item('panel', style="custom"),
                                     label="Fiducials", show_border=True,
                                     show_labels=False),
                              show_labels=False),
                       show_labels=False),
                resizable=True,
                buttons=NoButtons)

    def __init__(self, subject=None, subjects_dir=None,
                 **kwargs):  # noqa: D102
        super(FiducialsFrame, self).__init__(**kwargs)

        subjects_dir = get_subjects_dir(subjects_dir)
        if subjects_dir is not None:
            self.spanel.subjects_dir = subjects_dir

        # a subject that is not available is silently ignored
        if subject is not None and subject in self.spanel.subjects:
            self.spanel.subject = subject

    @on_trait_change('scene.activated')
    def _init_plot(self):
        """Create the head surface and fiducial point objects once the scene
        is activated, and register the mouse picker."""
        _toggle_mlab_render(self, False)
        try:
            # head surface (bem); replotted whenever the model's tris change
            color = defaults['mri_color']
            self.mri_obj = SurfaceObject(points=self.model.points,
                                         color=color, tri=self.model.tris,
                                         scene=self.scene)
            self.model.on_trait_change(self._on_mri_src_change, 'tris')
            self.panel.hsp_obj = self.mri_obj

            # fiducials: one PointObject per landmark, fed one-way from the
            # panel's coordinates and from this frame's point_scale
            for fid in ('lpa', 'nasion', 'rpa'):
                point_obj = PointObject(scene=self.scene,
                                        color=defaults['%s_color' % fid],
                                        has_norm=True,
                                        point_scale=self.point_scale)
                setattr(self, '%s_obj' % fid, point_obj)
                self.panel.sync_trait(fid, point_obj, 'points', mutual=False)
                self.sync_trait('point_scale', point_obj, mutual=False)

            self.headview.left = True
        finally:
            # never leave rendering disabled, even if setup fails
            _toggle_mlab_render(self, True)

        # picker
        self.scene.mayavi_scene.on_mouse_pick(self.panel._on_pick, type='cell')

    def _on_mri_src_change(self):
        """Replot the head surface after the model's triangulation changes;
        clear it when the model has no points or triangles."""
        if (not np.any(self.model.points)) or (not np.any(self.model.tris)):
            self.mri_obj.clear()
            return

        self.mri_obj.points = self.model.points
        self.mri_obj.tri = self.model.tris
        self.mri_obj.plot()
Example #26
0
    def traits_view(self):
        ss = '''
QLineEdit {font-size: 14px}
QGroupBox {background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
                                      stop: 0 #E0E0E0, stop: 1 #FFFFFF);
           border: 2px solid gray;
           border-radius: 5px;
           margin-top: 1ex; /* leave space at the top for the title */
           font-size: 14px;
           font-weight: bold;}
QGroupBox::title {subcontrol-origin: margin;
                  subcontrol-position: top left; /* position at the top center */
                  padding: 2 3px;}
QComboBox {font-size: 14px}
QLabel {font-size: 14px}
QToolBox::tab {font-size: 15px}
QToolTip {font-size: 14px}
'''

        add_button = icon_button_editor(
            'add_button',
            'add',
            # enabled_when='ok_add',
            tooltip='Add run')

        save_button = icon_button_editor('save_button',
                                         'disk',
                                         tooltip='Save queue to file')

        edit_button = icon_button_editor('edit_mode_button',
                                         'table_edit',
                                         enabled_when='edit_enabled',
                                         tooltip='Toggle edit mode')

        clear_button = icon_button_editor(
            'clear_button',
            'table_row_delete',
            tooltip='Clear all runs added using "frequency"')

        email_grp = VGroup(HGroup(
            queue_factory_item('use_email',
                               label='Use Email',
                               tooltip='Send email notifications'),
            queue_factory_item('use_group_email',
                               tooltip='Email a group of users',
                               label='Email Group'),
            icon_button_editor(queue_factory_name('edit_emails'),
                               'cog',
                               tooltip='Edit user group')),
                           Item(queue_factory_name('email')),
                           show_border=True,
                           label='Email')

        user_grp = HGroup(UItem(
            queue_factory_name('username'),
            show_label=False,
            editor=ComboboxEditor(name=queue_factory_name('usernames'))),
                          icon_button_editor(queue_factory_name('edit_user'),
                                             'database_edit'),
                          show_border=True,
                          label='User')
        ms_ed_grp = VGroup(HGroup(
            queue_factory_item(
                'mass_spectrometer',
                show_label=False,
                editor=EnumEditor(
                    name=queue_factory_name('mass_spectrometers'))),
            queue_factory_item(
                'extract_device',
                show_label=False,
                editor=EnumEditor(name=queue_factory_name('extract_devices'))),
            queue_factory_item(
                'load_name',
                width=150,
                label='Load',
                editor=ComboboxEditor(name=queue_factory_name('load_names'))),
            icon_button_editor(
                'generate_queue_button',
                'brick-go',
                tooltip='Generate a experiment queue from the selected load',
                enabled_when='load_name'),
            icon_button_editor(
                'edit_queue_config_button',
                'cog',
                tooltip='Configure experiment queue generation')),
                           HGroup(
                               queue_factory_item(
                                   'queue_conditionals_name',
                                   label='Queue Conditionals',
                                   editor=EnumEditor(name=queue_factory_name(
                                       'available_conditionals')))),
                           label='Spectrometer/Extract Device',
                           show_border=True)
        delay_grp = VGroup(queue_factory_item('delay_before_analyses'),
                           queue_factory_item('delay_between_analyses'),
                           show_border=True,
                           label='Delays')
        queue_grp = VGroup(user_grp,
                           email_grp,
                           ms_ed_grp,
                           delay_grp,
                           label='Queue')

        button_bar = HGroup(
            save_button, add_button, clear_button, edit_button,
            CustomLabel(run_factory_name('edit_mode_label'),
                        color='red',
                        width=40), spring)
        button_bar2 = HGroup(
            Item('auto_increment_id', label='Auto Increment L#'),
            Item('auto_increment_position', label='Position'),
        )
        edit_grp = VFold(
            queue_grp,
            VGroup(self._get_info_group(),
                   self._get_extract_group(),
                   enabled_when=queue_factory_name('ok_make'),
                   label='General'), self._get_script_group(),
            self._get_truncate_group())

        v = View(
            VGroup(
                button_bar,
                button_bar2,
                UItem('pane.info_label', style='readonly'),
                edit_grp,

                # lower_button_bar,
                style_sheet=ss),
            kind='live',
            width=225)
        return v
Example #27
0
    def traits_view(self):
        """Build the 'Color Inspector' view.

        The view shows the path and the per-band highlight editors (a
        non-mutable list of instance editors) above the plot ``container``.
        """
        # The earlier threshold/contrast/colormap control layout was left
        # here as ~45 lines of commented-out code followed by an unreachable
        # second `return v`; both have been removed (see version control).
        ctrl_grp = VGroup(
            Item('path', show_label=False),
            Item('highlight_bands',
                 editor=ListEditor(mutable=False, style='custom',
                                   editor=InstanceEditor())))
        v = View(ctrl_grp,
                 Item('container', show_label=False,
                      editor=ComponentEditor()),
                 title='Color Inspector',
                 resizable=True,
                 height=800,
                 width=900)
        return v
Example #28
0
class DCBTestModel(BMCSModel, Vis2D):
    '''Double cantilever beam simulation on a regular FETS3D8H grid.

    Combines a material model, cross section, geometry, boundary conditions
    and a time line into a ``TimeLoop`` (see ``tloop``).  The cached
    properties below declare their dependencies on the change events
    ``CS``, ``BC``, ``GEO``, ``MESH``, ``MAT``, ``TIME`` and ``ALG``
    (assumed to be declared in the base classes).
    '''

    #=========================================================================
    # Tree node attributes
    #=========================================================================
    node_name = 'double cantilever beam simulation'

    tree_node_list = List([])

    def _make_node_list(self):
        '''Return the sub-models shown as children of this node in the tree.'''
        return [
            self.tline,
            self.mats_eval,
            self.cross_section,
            self.geometry,
        ]

    def _tree_node_list_default(self):
        return self._make_node_list()

    def _update_node_list(self):
        self.tree_node_list = self._make_node_list()

    #=========================================================================
    # Interactive control of the time loop
    #=========================================================================
    def init(self):
        self.tloop.init()

    def eval(self):
        return self.tloop.eval()

    def pause(self):
        self.tloop.paused = True

    def stop(self):
        self.tloop.restart = True

    #=========================================================================
    # Test setup parameters
    #=========================================================================
    loading_scenario = Instance(LoadingScenario)

    def _loading_scenario_default(self):
        return LoadingScenario()

    cross_section = Instance(CrossSection)

    def _cross_section_default(self):
        return CrossSection()

    geometry = Instance(Geometry)

    def _geometry_default(self):
        return Geometry()

    #=========================================================================
    # Discretization
    #=========================================================================
    n_e_x = Int(2, auto_set=False, enter_set=True)
    n_e_y = Int(8, auto_set=False, enter_set=True)
    n_e_z = Int(1, auto_set=False, enter_set=True)

    # Prescribed displacement value applied by the control boundary conditions.
    w_max = Float(0.01, BC=True, auto_set=False, enter_set=True)

    #=========================================================================
    # Material model
    #=========================================================================
    mats_eval_type = Trait('scalar damage',
                           {'elastic': MATS3DElastic,
                            'scalar damage': MATS3DScalarDamage,
                            'microplane damage (eeq)': MATS3DMplDamageEEQ,
                            'microplane damage (odf)': MATS3DMplDamageODF,
                            },
                           MAT=True)

    @on_trait_change('mats_eval_type')
    def _set_mats_eval(self):
        # ``mats_eval_type_`` is the mapped value, i.e. the selected class.
        self.mats_eval = self.mats_eval_type_()

    @on_trait_change('BC,MAT,MESH')
    def reset_node_list(self):
        self._update_node_list()

    mats_eval = Instance(IMATSEval, MAT=True)
    '''Material model'''

    def _mats_eval_default(self):
        return self.mats_eval_type_()

    material = Property

    def _get_material(self):
        return self.mats_eval

    #=========================================================================
    # Finite element type
    #=========================================================================
    fets_eval = Property(Instance(FETS3D8H), depends_on='CS,MAT')
    '''Finite element time stepper implementing the corrector
    predictor operators at the element level'''

    @cached_property
    def _get_fets_eval(self):
        return FETS3D8H()

    # GEO is listed because every boundary condition collected below is
    # rebuilt on GEO; without it this cache kept BCs bound to an old mesh.
    bcond_mngr = Property(Instance(BCondMngr), depends_on='CS,BC,GEO,MESH')
    '''Boundary condition manager
    '''

    @cached_property
    def _get_bcond_mngr(self):
        bc_list = [
            self.fixed_left_x,
            self.fixed_top_y,
            # self.link_right_cs,
            self.uniform_control_bc,
        ]  # + self.link_right_x

        return BCondMngr(bcond_list=bc_list)

    fixed_left_x = Property(depends_on='CS,BC,GEO,MESH')
    '''Fixed boundary condition'''

    @cached_property
    def _get_fixed_left_x(self):
        # NOTE(review): the whole slice fe_grid[0, :, :, 0, :, :] is fixed in
        # dimension 0; geometry.a is not taken into account here.
        return BCSlice(slice=self.fe_grid[0, :, :, 0, :, :],
                       var='u',
                       dims=[0],
                       value=0)

    fixed_top_y = Property(depends_on='CS,BC,GEO,MESH')
    '''Fixed boundary condition'''

    @cached_property
    def _get_fixed_top_y(self):
        return BCSlice(slice=self.fe_grid[:, -1, :, :, -1, :],
                       var='u',
                       dims=[1, 2],
                       value=0)

    link_right_cs = Property(depends_on='CS,BC,GEO,MESH')
    '''Linked (kinematic constraint) boundary condition'''

    @cached_property
    def _get_link_right_cs(self):
        f_dof = self.fe_grid[-1, :, -1, -1, :, -1]
        b_dof = self.fe_grid[-1, :, 0, -1, :, 0]
        return BCSlice(name='link_cs',
                       slice=f_dof,
                       link_slice=b_dof,
                       dims=[0],
                       link_coeffs=[1],
                       value=0)

    link_right_x = Property(depends_on='CS,BC,GEO,MESH')
    '''Linked (kinematic constraint) boundary conditions'''

    @cached_property
    def _get_link_right_x(self):
        top = self.fe_grid[-1, -1, 0, -1, -1, 0]
        bot = self.fe_grid[-1, 0, 0, -1, 0, 0]
        linked = self.fe_grid[-1, 1:, 0, -1, 0, 0]

        Ty = top.dof_X[0, 0, 1]
        By = bot.dof_X[0, 0, 1]

        Ly = linked.dof_X[:, :, 1].flatten()

        H = Ty - By
        link_ratios = Ly / H
        top_dof = top.dofs[0, 0, 0]
        bot_dof = bot.dofs[0, 0, 0]
        linked_dofs = linked.dofs[:, :, 0].flatten()
        # Each linked dof follows a linear interpolation between the bottom
        # and top dofs, weighted by its y coordinate relative to H.
        return [BCDof(var='u',
                      dof=linked_dof,
                      value=0,
                      link_dofs=[bot_dof, top_dof],
                      link_coeffs=[1 - link_ratio, link_ratio])
                for linked_dof, link_ratio in zip(linked_dofs, link_ratios)]

    control_bc = Property(depends_on='CS,BC,GEO,MESH')
    '''Control boundary condition (prescribed w_max at a single node)'''

    @cached_property
    def _get_control_bc(self):
        return BCSlice(slice=self.fe_grid[-1, 0, 0, -1, 0, 0],
                       var='u',
                       dims=[0],
                       value=self.w_max)

    uniform_control_bc = Property(depends_on='CS,BC,GEO,MESH')
    '''Control boundary condition (prescribed w_max on the last x-layer)'''

    @cached_property
    def _get_uniform_control_bc(self):
        return BCSlice(slice=self.fe_grid[-1, :, :, -1, :, :],
                       var='u',
                       dims=[0],
                       value=self.w_max)

    dots_grid = Property(Instance(DOTSGrid), depends_on='CS,MAT,GEO,MESH,FE')
    '''Discretization object.
    '''

    @cached_property
    def _get_dots_grid(self):
        cs = self.cross_section
        geo = self.geometry
        return DOTSGrid(L_x=cs.h,
                        L_y=geo.L,
                        L_z=cs.b,
                        n_x=self.n_e_x,
                        n_y=self.n_e_y,
                        n_z=self.n_e_z,
                        fets=self.fets_eval,
                        mats=self.mats_eval)

    fe_grid = Property

    def _get_fe_grid(self):
        return self.dots_grid.mesh

    tline = Instance(TLine)

    def _tline_default(self):
        t_max = 1.0
        d_t = 0.1
        return TLine(
            min=0.0,
            step=d_t,
            max=t_max,
            time_change_notifier=self.time_changed,
        )

    k_max = Int(200, ALG=True)
    tolerance = Float(1e-4, ALG=True)
    tloop = Property(Instance(TimeLoop),
                     depends_on='MAT,GEO,MESH,CS,TIME,ALG,BC')
    '''Algorithm controlling the time stepping.
    '''

    @cached_property
    def _get_tloop(self):
        return TimeLoop(ts=self.dots_grid,
                        k_max=self.k_max,
                        tolerance=self.tolerance,
                        tline=self.tline,
                        bc_mngr=self.bcond_mngr)

    def get_PW(self):
        '''Return the recorded force/displacement history.

        Returns a tuple ``(F_int_t, U_t)``: the internal forces summed over
        the unique dimension-0 dofs of ``fe_grid[-1, :, :, -1, :, :]`` for
        each recorded step, and the displacement of the first of those dofs.
        '''
        record_dofs = np.unique(self.fe_grid[-1, :, :,
                                             -1, :, :].dofs[:, :, 0].flatten())
        Fd_int_t = np.array(self.tloop.F_int_record)
        Ud_t = np.array(self.tloop.U_record)
        F_int_t = np.sum(Fd_int_t[:, record_dofs], axis=1)
        U_t = Ud_t[:, record_dofs[0]]
        return F_int_t, U_t

    viz2d_classes = {
        'F-w': Viz2DForceDeflectionX,
        'load function': Viz2DLoadControlFunction,
    }

    traits_view = View(Item('mats_eval_type'), )

    tree_view = traits_view
Example #29
0
class MyVisuClass(HasTraits):
    """Mayavi scene that draws a 2D structural model and its deformation.

    The created actors and the visibility toggle flags are stored on the
    module-level ``window`` object (defined outside this class).  The
    ``toggle_*`` methods read and update that object.
    """

    scene = Instance(MlabSceneModel, ())

    # the layout of the dialog created
    view = View(Item('scene',
                     editor=SceneEditor(scene_class=Scene),
                     height=250,
                     width=300,
                     show_label=False),
                resizable=True)

    @staticmethod
    def _line_width():
        """Return the element line width: 1.0 on Windows, 2.0 elsewhere."""
        return 1.0 if os.name == 'nt' else 2.0

    @staticmethod
    def _set_visible(actors, visible):
        """Set the ``visible`` flag on every actor in ``actors``."""
        for actor in actors:
            actor.visible = visible

    def redraw_scene(self):
        """Clear the embedded scene and reset its background colour."""
        mlab.clf(figure=self.scene.mayavi_scene)
        mlab.figure(figure=self.scene.mayavi_scene, bgcolor=(0.15, 0.15, 0.15))

    def plot_deformation(self, _strdata):
        """Draw each element at its deformed position (cyan lines).

        Node positions are offset by ``def_factor`` times the node's
        ``defX``/``defY``.  The created line actors are appended to
        ``window.elemdef``.
        """
        width_line = self._line_width()
        self.def_factor = 1.0
        scale = self.def_factor
        color_def = (0 / 256, 255 / 256, 255 / 256)  # cyan
        fig = self.scene.mayavi_scene

        self.scene.disable_render = True
        try:
            for elem in _strdata.Elems.elements:
                n1 = _strdata.Nodes.findNodeById(elem.n1)
                n2 = _strdata.Nodes.findNodeById(elem.n2)
                xs = [n1.x + scale * n1.defX, n2.x + scale * n2.defX]
                ys = [n1.y + scale * n1.defY, n2.y + scale * n2.defY]
                window.elemdef.append(
                    mlab.plot3d(xs,
                                ys,
                                [0, 0],
                                figure=fig,
                                line_width=width_line,
                                opacity=0.8,
                                tube_radius=None,
                                color=color_def))
        finally:
            # Re-enable rendering even if plotting failed part-way.
            self.scene.disable_render = False

    def plot_model_geometry(self, _strdata):
        """Clear the scene and draw the model.

        Draws nodes with labels, elements with labels, load arrows with
        values, constraint arrows and the orientation axes.
        """
        mlab.clf(figure=self.scene.mayavi_scene)

        nodes = _strdata.Nodes.nodes
        self._plot_node_points(nodes)

        self.scene.disable_render = True
        try:
            self._plot_node_labels(nodes)
            self._plot_elements(_strdata)
        finally:
            self.scene.disable_render = False

        self._plot_loads(_strdata)
        self._plot_constraints(_strdata)

        # global axes
        mlab.orientation_axes(figure=self.scene.mayavi_scene,
                              opacity=1.0,
                              line_width=1.0)

    def _plot_node_points(self, nodes):
        """Draw every node as a sphere glyph in the z=0 plane (window.pts)."""
        x = np.array([node.x for node in nodes])
        y = np.array([node.y for node in nodes])
        z = np.zeros(len(nodes))  # zero element for 2d analysis
        window.pts = mlab.points3d(x,
                                   y,
                                   z,
                                   figure=self.scene.mayavi_scene,
                                   resolution=16,
                                   scale_factor=0.03)

    def _plot_node_labels(self, nodes):
        """Add an "N<id>" text label at every node (window.nodeText)."""
        fig = self.scene.mayavi_scene
        for node in nodes:
            window.nodeText.append(
                mlab.text3d(node.x,
                            node.y,
                            0,
                            "N" + str(node.id),
                            figure=fig,
                            scale=0.08,
                            color=(1, 1, 1)))

    def _plot_elements(self, _strdata):
        """Draw every element and an "E<id>" label at its midpoint."""
        fig = self.scene.mayavi_scene
        node_set = _strdata.Nodes
        for elem in _strdata.Elems.elements:
            self.LinePlot(elem.n1, elem.n2, node_set, figure=fig)

            n1 = node_set.findNodeById(elem.n1)
            n2 = node_set.findNodeById(elem.n2)
            midpt = Node.getMidPointCoordinates(n1, n2)

            window.elemText.append(
                mlab.text3d(midpt[0],
                            midpt[1],
                            0,
                            "E" + str(elem.id),
                            figure=fig,
                            scale=0.08,
                            color=(1, 1, 1),
                            orientation=(0, 0, 0.0)))

    def _plot_loads(self, _strdata):
        """Draw load arrows ending at their nodes, plus a value label each."""
        load_scale_factor = 0.2
        color_load = (135 / 256, 206 / 256, 250 / 256)
        fig = self.scene.mayavi_scene
        loads = _strdata.Loads.loads

        xs, ys, us, vs = [], [], [], []
        for load in loads:
            node = _strdata.Nodes.findNodeById(load.nodeId)
            lx = load.loadX
            ly = load.loadY
            # Start the arrow upstream so that its head points at the node.
            x = node.x - lx * load_scale_factor
            y = node.y - ly * load_scale_factor
            xs.append(x)
            ys.append(y)
            us.append(lx)
            vs.append(ly)

            window.loadText.append(
                mlab.text3d(x,
                            y,
                            0,
                            str(load.vecLength),
                            figure=fig,
                            scale=0.08,
                            color=color_load,
                            orientation=(0, 0, 0)))

        window.loadvecs = mlab.quiver3d(np.array(xs),
                                        np.array(ys),
                                        np.zeros(len(loads)),
                                        np.array(us),
                                        np.array(vs),
                                        np.zeros(len(loads)),
                                        figure=fig,
                                        scale_factor=load_scale_factor,
                                        mode='2darrow',
                                        line_width=2.0,
                                        color=color_load)

    def _plot_constraints(self, _strdata):
        """Draw red arrows for constrained directions (window.constvecs).

        Constraints with ``cX == 0 and cY == 0`` are skipped.  A constraint
        with ``cX == 1 and cY == 1`` is drawn as two separate axis-aligned
        arrows.  Any other case is drawn as one arrow along ``(cX, cY)``.
        """
        const_scale_factor = 0.2
        color_consts = (255 / 256, 0 / 256, 0 / 256)  # red

        xs, ys, us, vs = [], [], [], []
        for const in _strdata.Consts.constraints:
            node = _strdata.Nodes.findNodeById(const.nodeId)
            cx = const.cX
            cy = const.cY
            if cx == 0 and cy == 0:
                continue
            if cx == 1 and cy == 1:
                arrows = [(node.x - cx * const_scale_factor, node.y, cx, 0),
                          (node.x, node.y - cy * const_scale_factor, 0, cy)]
            else:
                arrows = [(node.x - cx * const_scale_factor,
                           node.y - cy * const_scale_factor, cx, cy)]
            for x, y, u, v in arrows:
                xs.append(x)
                ys.append(y)
                us.append(u)
                vs.append(v)

        window.constvecs = mlab.quiver3d(np.array(xs),
                                         np.array(ys),
                                         np.zeros(len(xs)),
                                         np.array(us),
                                         np.array(vs),
                                         np.zeros(len(xs)),
                                         figure=self.scene.mayavi_scene,
                                         scale_factor=const_scale_factor,
                                         mode='arrow',
                                         resolution=32,
                                         color=color_consts)

    @on_trait_change('scene.activated')
    def update_plot(self):
        self.redraw_scene()

    def reset_view_xy(self):
        mlab.view(0, 0)

    def showpip(self):
        mlab.show_pipeline()

    def toggle_nid(self):
        """Toggle the visibility of the node id labels."""
        window.tgl_nid = not window.tgl_nid
        self._set_visible(window.nodeText, window.tgl_nid)

    def toggle_eid(self):
        """Toggle the visibility of the element id labels."""
        window.tgl_eid = not window.tgl_eid
        self._set_visible(window.elemText, window.tgl_eid)

    def toggle_load(self):
        """Toggle the visibility of the load arrows and their value labels."""
        window.tgl_lval = not window.tgl_lval
        window.loadvecs.visible = window.tgl_lval
        self._set_visible(window.loadText, window.tgl_lval)

    def toggle_consts(self):
        """Toggle the visibility of the constraint arrows."""
        window.tgl_consts = not window.tgl_consts
        window.constvecs.visible = window.tgl_consts

    def toggle_deformation(self):
        """Toggle the visibility of the deformed-shape lines."""
        window.tgl_deformation = not window.tgl_deformation
        self.scene.disable_render = True
        try:
            self._set_visible(window.elemdef, window.tgl_deformation)
        finally:
            self.scene.disable_render = False

    @staticmethod
    def LinePlot(_sid, _eid, _nodes, figure=None):
        """Draw a straight line between the nodes with ids ``_sid``/``_eid``.

        ``figure`` optionally selects the target Mayavi scene.  When it is
        None, the current figure is used, as before.
        """
        spt = _nodes.findNodeById(_sid)
        ept = _nodes.findNodeById(_eid)
        extra = {} if figure is None else {'figure': figure}
        mlab.plot3d([spt.x, ept.x], [spt.y, ept.y], [0.0, 0.0],
                    line_width=MyVisuClass._line_width(),
                    opacity=1.0,
                    tube_radius=None,
                    **extra)
Example #30
0
class TVTKClassChooser(HasTraits):

    # The selected object, is None if no valid class_name was made.
    object = Property

    # The TVTK class name to choose.
    class_name = Str('', desc='class name of TVTK class (case sensitive)')

    # The string to search for in the class docs -- the search supports
    # 'and' and 'or' keywords.
    search = Str('', desc='string to search in TVTK class documentation '\
                          'supports the "and" and "or" keywords. '\
                          'press <Enter> to start search. '\
                          'This is case insensitive.')

    clear_search = Button

    # The class documentation.
    doc = Str(_search_help_doc)

    # Completions for the choice of class.
    completions = List(Str)

    # List of available class names as strings.
    available = List(TVTK_CLASSES)

    ########################################
    # Private traits.

    finder = Instance(DocSearch)

    n_completion = Int(25)

    ########################################
    # View related traits.

    view = View(Group(Item(name='class_name',
                           editor=EnumEditor(name='available')),
                      Item(name='class_name',
                           has_focus=True
                           ),
                      Item(name='search',
                           editor=TextEditor(enter_set=True,
                                             auto_set=False)
                           ),
                      Item(name='clear_search',
                           show_label=False),
                      Item('_'),
                      Item(name='completions',
                           editor=ListEditor(columns=3),
                           style='readonly'
                           ),
                      Item(name='doc',
                           resizable=True,
                           label='Documentation',
                           style='custom')
                      ),
                id='tvtk_doc',
                resizable=True,
                width=800,
                height=600,
                title='TVTK class chooser',
                buttons = ["OK", "Cancel"]
                )
    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **kw):
        """Initialise the traits and snapshot the unfiltered class list.

        ``_search_changed`` restores ``available`` from this snapshot after a
        search has narrowed it.
        """
        super(TVTKClassChooser, self).__init__(**kw)
        # Keep an independent copy so later edits to ``available`` don't alter it.
        self._orig_available = [name for name in self.available]

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _get_object(self):
        """Getter for ``object``: an instance of the chosen TVTK class.

        Returns None when ``class_name`` is empty, does not name an attribute
        of ``tvtk``, or cannot be called without arguments.
        """
        name = self.class_name
        if not name:
            return None
        try:
            return getattr(tvtk, name)()
        except (AttributeError, TypeError):
            return None

    def _class_name_changed(self, value):
        """Refresh completions and documentation when ``class_name`` changes.

        ``completions`` becomes up to ``n_completion`` entries of
        ``available`` that start with ``value``.  If exactly one entry
        matches and it differs from ``value``, ``class_name`` is auto-completed
        to it.  Otherwise ``doc`` shows the chosen class's documentation, or
        the search help text when no valid object can be built.
        """
        matches = [x for x in self.available if x.startswith(value)]
        self.completions = matches[:self.n_completion]
        if len(matches) == 1 and value != matches[0]:
            # Assigning class_name re-enters this handler with the completed
            # name, which already refreshes completions and doc.  Return here
            # so that ``self.object`` (which instantiates a TVTK class) is not
            # evaluated a second time for the same name.
            self.class_name = matches[0]
            return

        o = self.object
        if o is not None:
            self.doc = get_tvtk_class_doc(o)
        else:
            self.doc = _search_help_doc

    def _finder_default(self):
        """Default value for the ``finder`` trait: a fresh ``DocSearch``."""
        searcher = DocSearch()
        return searcher

    def _clear_search_fired(self):
        """Handle the ``clear_search`` button by emptying the search text."""
        self.search = ""

    def _search_changed(self, value):
        """Run a documentation search and narrow the class list accordingly.

        Queries shorter than three characters, and searches with no hits,
        restore the full class list.  A single hit is selected directly as
        ``class_name``.  Several hits replace ``available`` and
        ``completions``.
        """
        if len(value) < 3:
            self.available = self._orig_available
            return

        matches = self.finder.search(str(value))
        n_matches = len(matches)
        if n_matches == 1:
            self.class_name = matches[0]
        elif n_matches:
            self.available = matches
            self.completions = matches[:self.n_completion]
        else:
            self.available = self._orig_available