class IrrigationArea(HasTraits):
    """Traits model holding a name, a surface and a crop selection."""
    # Free-form name (string trait).
    name = Str
    # Surface; the trait description gives the unit as hectares.
    surface = Float(desc='Surface [ha]')
    # Crop, restricted to the listed choices ('Alfalfa' is listed first).
    crop = Enum('Alfalfa', 'Wheat', 'Cotton')
Example #2
0
class BaseLakeShoreController(CoreDevice):
    """Base device class for LakeShore controllers.

    Readings are obtained with ``<units>RDG? <input>`` queries and setpoints
    are queried/written with ``SETP?``/``SETP``. ``iolist`` holds the input
    trait names polled by :meth:`update`; ``iomap`` holds, per input, the name
    of the associated setpoint trait (or ``None`` when the input is unmapped).
    """
    units = Enum('C', 'K')
    scan_func = 'update'

    input_a = Float
    input_b = Float
    setpoint1 = Float(auto_set=False, enter_set=True)
    setpoint1_readback = Float
    setpoint2 = Float(auto_set=False, enter_set=True)
    setpoint2_readback = Float
    range_tests = List
    num_inputs = Int
    ionames = List
    iolist = List
    iomap = List

    def load_additional_args(self, config):
        """Load units, range tests and the input/setpoint map from ``config``.

        Parameters
        ----------
        config : ConfigParser-like
            Must provide ``has_section`` and ``items``.

        Returns
        -------
        bool
            Always ``True``.
        """
        self.set_attribute(config, 'units', 'General', 'units', default='K')

        # [Range]
        # 1=v<10
        # 2=10<v<30
        # 3=v>30

        if config.has_section('Range'):
            items = config.items('Range')
        else:
            items = [(1, 'v<10'), (2, '10<v<30'), (3, 'v>30')]

        if items:
            self.range_tests = [RangeTest(*i) for i in items]

        if config.has_section('IOConfig'):
            iodict = dict(config.items('IOConfig'))
            self.num_inputs = int(iodict['num_inputs'])
            for tag in string.ascii_lowercase[:self.num_inputs]:
                input_name = 'input_{}'.format(tag)

                # The display name is optional: fall back to the input name.
                # (A missing dict key raises KeyError, not ValueError, so the
                # previous ``except ValueError`` fallback was never reached.)
                self.ionames.append(
                    iodict.get('{}_name'.format(input_name), input_name))
                self.iolist.append(input_name)

                mapsetpoint = iodict[input_name]
                if mapsetpoint.lower() == 'none':
                    self.iomap.append(None)
                else:
                    self.iomap.append(mapsetpoint)
        else:
            self.num_inputs = 2
            self.ionames = ['', '', '', '']
            self.iomap = ['setpoint1', 'setpoint2', 'setpoint3', 'setpoint4']

        return True

    def initialize(self, *args, **kw):
        """Set the communicator write terminator, then run the base init."""
        self.communicator.write_terminator = chr(10)  # line feed \n
        return super(BaseLakeShoreController, self).initialize(*args, **kw)

    def test_connection(self):
        """Return True if the ``*IDN?`` response matches ``IDN_RE``."""
        self.tell('*CLS')
        resp = self.ask('*IDN?')
        # Guard against an empty/missing response before regex matching.
        if not resp:
            return False
        return bool(IDN_RE.match(resp))

    def update(self, **kw):
        """Refresh all input traits and setpoint readbacks.

        Returns
        -------
        float
            The current value of ``input_a``.
        """
        for tag in self.iolist:
            func = getattr(self, 'read_{}'.format(tag))
            v = func(**kw)
            setattr(self, tag, v)

        for tag in self.iomap:
            # Unmapped inputs are stored as None; there is no setpoint to
            # read back and no '<tag>_readback' trait to assign.
            if tag is None:
                continue
            v = self.read_setpoint(tag)
            setattr(self, '{}_readback'.format(tag), v)

        return self.input_a

    def setpoints_achieved(self, tol=1):
        """Check that every mapped input is within ``tol`` of its setpoint.

        Returns
        -------
        bool
            ``False`` as soon as one mapped input is out of tolerance,
            otherwise ``True``.
        """
        for i, (tag, key) in enumerate(zip(self.iomap,
                                           string.ascii_lowercase)):
            idx = i + 1
            v = self._read_input(key, self.units)
            if tag is not None:
                setpoint = getattr(self, tag)
                if abs(v - setpoint) > tol:
                    return False
                self.debug('setpoint {} achieved'.format(idx))

        return True

    @get_float(default=0)
    def read_setpoint(self, output, verbose=False):
        """Query the setpoint for ``output`` (e.g. ``'setpoint1'``).

        Only the digits of ``output`` are sent to the device. Nothing is
        queried when ``output`` is ``None``.
        """
        if output is not None:
            return self.ask('SETP? {}'.format(re.sub('[^0-9]', '', output)),
                            verbose=verbose)

    def set_setpoints(self, *setpoints):
        """Assign ``setpoint1``, ``setpoint2``, ... from positional values.

        ``None`` entries are skipped, leaving that setpoint unchanged.
        """
        for i, v in enumerate(setpoints):
            if v is not None:
                idx = i + 1
                setattr(self, 'setpoint{}'.format(idx), v)

    def set_setpoint(self, v, output=1):
        """Select the range for ``v`` and then send the setpoint command."""
        self.set_range(v, output)
        self.tell('SETP {},{}'.format(output, v))

    def set_range(self, v, output):
        """Send ``RANGE`` for the first configured range test matching ``v``.

        Sleeps for one second afterwards, whether or not a range was sent.
        """
        for r in self.range_tests:
            ra = r.test(v)
            if ra:
                self.tell('RANGE {},{}'.format(output, ra))
                break

        sleep(1)

    def read_input(self, v, **kw):
        """Read an input given its letter or its 1-based integer index."""
        if isinstance(v, int):
            v = string.ascii_lowercase[v - 1]
        return self._read_input(v, self.units, **kw)

    def read_input_a(self, **kw):
        """Read input ``a`` in the configured units."""
        return self._read_input('a', self.units, **kw)

    def read_input_b(self, **kw):
        """Read input ``b`` in the configured units."""
        return self._read_input('b', self.units, **kw)

    @get_float(default=0)
    def _read_input(self, tag, mode='C', verbose=False):
        """Send ``<mode>RDG? <tag>`` and return the response."""
        return self.ask('{}RDG? {}'.format(mode, tag), verbose=verbose)

    def _setpoint1_changed(self):
        self.set_setpoint(self.setpoint1, 1)

    def _setpoint2_changed(self):
        self.set_setpoint(self.setpoint2, 2)
class SurfaceObject(Object):
    """Represent a solid object in a mayavi scene.

    Notes
    -----
    Doesn't automatically update plot because update requires both
    :attr:`points` and :attr:`tris`. Call :meth:`plot` after updating both
    attributes.
    """

    rep = Enum("Surface", "Wireframe")
    tris = Array(int, shape=(None, 3))

    surf = Instance(Surface)
    surf_rear = Instance(Surface)

    view = View(
        HGroup(Item('visible', show_label=False),
               Item('color', show_label=False), Item('opacity')))

    def __init__(self, block_behind=False, **kwargs):  # noqa: D102
        self._block_behind = block_behind
        # Same attribute name as used by _update_tris/_update_points (the
        # initialiser previously set a differently spelled attribute).
        self._deferred_tris_update = False
        super(SurfaceObject, self).__init__(**kwargs)

    def clear(self):  # noqa: D102
        if hasattr(self.src, 'remove'):
            self.src.remove()
        if hasattr(self.surf, 'remove'):
            self.surf.remove()
        if hasattr(self.surf_rear, 'remove'):
            self.surf_rear.remove()
        # Also reset surf_rear so no reference to the removed rear surface
        # is kept around.
        self.reset_traits(['src', 'surf', 'surf_rear'])

    @on_trait_change('scene.activated')
    def plot(self):
        """Add the points to the mayavi pipeline"""
        # Remember the camera scale so that re-plotting does not change it.
        _scale = self.scene.camera.parallel_scale
        self.clear()

        if not np.any(self.tris):
            return

        fig = self.scene.mayavi_scene
        surf = dict(rr=self.points, tris=self.tris)
        normals = _create_mesh_surf(surf, fig=fig)
        self.src = normals.parent
        rep = 'wireframe' if self.rep == 'Wireframe' else 'surface'
        # Add the opaque "inside" first to avoid the translucent "outside"
        # from being occluded (gh-5152)
        if self._block_behind:
            surf_rear = pipeline.surface(normals,
                                         figure=fig,
                                         color=self.color,
                                         representation=rep,
                                         line_width=1)
            surf_rear.actor.property.frontface_culling = True
            self.surf_rear = surf_rear
            self.sync_trait('color',
                            self.surf_rear.actor.property,
                            mutual=False)
            self.sync_trait('visible', self.surf_rear, 'visible')
            self.surf_rear.actor.property.opacity = 1.
        surf = pipeline.surface(normals,
                                figure=fig,
                                color=self.color,
                                representation=rep,
                                line_width=1)
        surf.actor.property.backface_culling = True
        self.surf = surf
        self.sync_trait('visible', self.surf, 'visible')
        self.sync_trait('color', self.surf.actor.property, mutual=False)
        self.sync_trait('opacity', self.surf.actor.property)

        self.scene.camera.parallel_scale = _scale

    @on_trait_change('tris')
    def _update_tris(self):
        # Only record the change; the triangles are pushed to the source on
        # the next points update.
        self._deferred_tris_update = True

    @on_trait_change('points')
    def _update_points(self):
        if Object._update_points(self):
            if self._deferred_tris_update:
                self.src.data.polys = self.tris
                # Clear the flag (a misspelled attribute used to be assigned
                # here, so the flag was never reset).
                self._deferred_tris_update = False
            self.src.update()  # necessary for SurfaceObject since Mayavi 4.5.0
Example #4
0
class Kit2FiffModel(HasPrivateTraits):
    """Data Model for Kit2Fiff conversion.

    - Markers are transformed into RAS coordinate system (as are the sensor
      coordinates).
    - Head shape digitizer data is transformed into neuromag-like space.
    """

    # Input Traits
    markers = Instance(CombineMarkersModel, ())
    sqd_file = File(exists=True, filter=kit_con_wildcard)
    allow_unknown_format = Bool(False)
    hsp_file = File(exists=True, filter=hsp_wildcard)
    fid_file = File(exists=True, filter=elp_wildcard)
    stim_coding = Enum(">", "<", "channel")
    stim_chs = Str("")
    stim_chs_array = Property(depends_on=['raw', 'stim_chs', 'stim_coding'])
    stim_chs_ok = Property(depends_on='stim_chs_array')
    stim_chs_comment = Property(depends_on='stim_chs_array')
    stim_slope = Enum("-", "+")
    stim_threshold = Float(1.)

    # Marker Points
    use_mrk = List(list(range(5)), desc="Which marker points to use for the "
                   "device head coregistration.")

    # Derived Traits
    mrk = Property(depends_on='markers.mrk3.points')

    # Polhemus Fiducials
    elp_raw = Property(depends_on=['fid_file'])
    hsp_raw = Property(depends_on=['hsp_file'])
    polhemus_neuromag_trans = Property(depends_on=['elp_raw'])

    # Polhemus data (in neuromag space)
    elp = Property(depends_on=['elp_raw', 'polhemus_neuromag_trans'])
    fid = Property(depends_on=['elp_raw', 'polhemus_neuromag_trans'])
    hsp = Property(depends_on=['hsp_raw', 'polhemus_neuromag_trans'])

    # trans
    dev_head_trans = Property(depends_on=['elp', 'mrk', 'use_mrk'])
    head_dev_trans = Property(depends_on=['dev_head_trans'])

    # event preview
    raw = Property(depends_on='sqd_file')
    misc_chs = Property(List, depends_on='raw')
    misc_chs_desc = Property(Str, depends_on='misc_chs')
    misc_data = Property(Array, depends_on='raw')
    can_test_stim = Property(Bool, depends_on='raw')

    # info
    sqd_fname = Property(Str, depends_on='sqd_file')
    hsp_fname = Property(Str, depends_on='hsp_file')
    fid_fname = Property(Str, depends_on='fid_file')
    can_save = Property(Bool, depends_on=['stim_chs_ok', 'fid',
                                          'elp', 'hsp', 'dev_head_trans'])

    # Show GUI feedback (like error messages and progress bar)
    show_gui = Bool(False)

    @cached_property
    def _get_can_save(self):
        """Only allow saving when all or no head shape elements are set."""
        if not self.stim_chs_ok:
            return False

        has_all_hsp = (np.any(self.dev_head_trans) and np.any(self.hsp) and
                       np.any(self.elp) and np.any(self.fid))
        if has_all_hsp:
            return True

        has_any_hsp = self.hsp_file or self.fid_file or np.any(self.mrk)
        return not has_any_hsp

    @cached_property
    def _get_can_test_stim(self):
        """Stim settings can be tested once a raw object is available."""
        return self.raw is not None

    @cached_property
    def _get_dev_head_trans(self):
        """Fit the markers (``mrk``) to the ``elp`` points.

        Returns the identity when markers or fiducials are missing, and
        ``None`` when fewer than 3 marker points are selected.
        """
        if (self.mrk is None) or not np.any(self.fid):
            return np.eye(4)

        src_pts = self.mrk
        dst_pts = self.elp

        n_use = len(self.use_mrk)
        if n_use < 3:
            if self.show_gui:
                error(None, "Estimating the device head transform requires at "
                      "least 3 marker points. Please adjust the markers used.",
                      "Not Enough Marker Points")
            return
        elif n_use < 5:
            src_pts = src_pts[self.use_mrk]
            dst_pts = dst_pts[self.use_mrk]

        trans = fit_matched_points(src_pts, dst_pts, out='trans')
        return trans

    @cached_property
    def _get_elp(self):
        """Points 3..7 of ``elp_raw`` after ``polhemus_neuromag_trans``."""
        if self.elp_raw is None:
            return np.empty((0, 3))
        pts = self.elp_raw[3:8]
        pts = apply_trans(self.polhemus_neuromag_trans, pts)
        return pts

    @cached_property
    def _get_elp_raw(self):
        """Read the points in ``fid_file``; at least 8 points are required.

        On failure ``fid_file`` is reset and the exception is re-raised.
        """
        if not self.fid_file:
            return

        try:
            pts = _read_dig_points(self.fid_file)
            if len(pts) < 8:
                raise ValueError("File contains %i points, need 8" % len(pts))
        except Exception as err:
            if self.show_gui:
                error(None, str(err), "Error Reading Fiducials")
            self.reset_traits(['fid_file'])
            raise
        else:
            return pts

    @cached_property
    def _get_fid(self):
        """First three points of ``elp_raw`` after the neuromag transform."""
        if self.elp_raw is None:
            return np.empty((0, 3))
        pts = self.elp_raw[:3]
        pts = apply_trans(self.polhemus_neuromag_trans, pts)
        return pts

    @cached_property
    def _get_fid_fname(self):
        """Base name of ``fid_file``, or ``'-'`` when unset."""
        if self.fid_file:
            return os.path.basename(self.fid_file)
        else:
            return '-'

    @cached_property
    def _get_head_dev_trans(self):
        """Inverse of ``dev_head_trans``."""
        return inv(self.dev_head_trans)

    @cached_property
    def _get_hsp(self):
        """``hsp_raw`` after ``polhemus_neuromag_trans`` (empty if missing)."""
        if (self.hsp_raw is None) or not np.any(self.polhemus_neuromag_trans):
            return np.empty((0, 3))
        else:
            pts = apply_trans(self.polhemus_neuromag_trans, self.hsp_raw)
            return pts

    @cached_property
    def _get_hsp_fname(self):
        """Base name of ``hsp_file``, or ``'-'`` when unset."""
        if self.hsp_file:
            return os.path.basename(self.hsp_file)
        else:
            return '-'

    @cached_property
    def _get_hsp_raw(self):
        """Read the head shape points, decimating overly large files.

        On failure ``hsp_file`` is reset and the exception is re-raised.
        """
        fname = self.hsp_file
        if not fname:
            return

        try:
            pts = _read_dig_points(fname)
            n_pts = len(pts)
            if n_pts > KIT.DIG_POINTS:
                msg = ("The selected head shape contains {n_in} points, "
                       "which is more than the recommended maximum ({n_rec}). "
                       "The file will be automatically downsampled, which "
                       "might take a while. A better way to downsample is "
                       "using FastScan.".
                       format(n_in=n_pts, n_rec=KIT.DIG_POINTS))
                if self.show_gui:
                    information(None, msg, "Too Many Head Shape Points")
                pts = _decimate_points(pts, 5)

        except Exception as err:
            if self.show_gui:
                error(None, str(err), "Error Reading Head Shape")
            self.reset_traits(['hsp_file'])
            raise
        else:
            return pts

    @cached_property
    def _get_misc_chs(self):
        """Indices of the channels in ``raw`` whose kind is MISC."""
        if not self.raw:
            return
        return [i for i, ch in enumerate(self.raw.info['chs']) if
                ch['kind'] == FIFF.FIFFV_MISC_CH]

    @cached_property
    def _get_misc_chs_desc(self):
        """Short text describing the range of MISC channel indices."""
        if self.misc_chs is None:
            return "No SQD file selected..."
        elif len(self.misc_chs) == 0:
            # Indexing [0] / [-1] below would raise IndexError on an empty
            # channel list.
            return "No MISC channels"
        elif np.all(np.diff(self.misc_chs) == 1):
            return "%i:%i" % (self.misc_chs[0], self.misc_chs[-1] + 1)
        else:
            return "%i... (discontinuous)" % self.misc_chs[0]

    @cached_property
    def _get_misc_data(self):
        """Load the data of the MISC channels from ``raw``."""
        if not self.raw:
            return

        prog = None
        if self.show_gui:
            # progress dialog with indefinite progress bar
            prog = ProgressDialog(title="Loading SQD data...",
                                  message="Loading stim channel data from SQD "
                                  "file ...")
            prog.open()
            prog.update(0)

        try:
            data, _ = self.raw[self.misc_chs]
        except Exception as err:
            if self.show_gui:
                error(None, "Error reading SQD data file: %s (Check the "
                      "terminal output for details)" % str(err),
                      "Error Reading SQD File")
            raise
        finally:
            # Close the dialog only if one was actually created; do not rely
            # on show_gui still having the value it had above.
            if prog is not None:
                prog.close()
        return data

    @cached_property
    def _get_mrk(self):
        """Marker points transformed with ``als_ras_trans``."""
        return apply_trans(als_ras_trans, self.markers.mrk3.points)

    @cached_property
    def _get_polhemus_neuromag_trans(self):
        """Transform built from the first three ``elp_raw`` points."""
        if self.elp_raw is None:
            return
        nasion, lpa, rpa = apply_trans(als_ras_trans, self.elp_raw[:3])
        trans = get_ras_to_neuromag_trans(nasion, lpa, rpa)
        return np.dot(trans, als_ras_trans)

    @cached_property
    def _get_raw(self):
        """Open ``sqd_file`` as a ``RawKIT`` without a stim channel.

        If the file format is unsupported, a warning is shown,
        ``allow_unknown_format`` is switched on and the file is read again.
        Any other error resets ``sqd_file`` and is re-raised.
        """
        if not self.sqd_file:
            return
        try:
            return RawKIT(self.sqd_file, stim=None,
                          allow_unknown_format=self.allow_unknown_format)
        except UnsupportedKITFormat as exception:
            warning(
                None,
                "The selected SQD file is written in an old file format (%s) "
                "that is not officially supported. Confirm that the results "
                "are as expected. This warning is displayed only once per "
                "session." % (exception.sqd_version,),
                "Unsupported SQD File Format")
            self.allow_unknown_format = True
            return self._get_raw()
        except Exception as err:
            self.reset_traits(['sqd_file'])
            if self.show_gui:
                error(None, "Error reading SQD data file: %s (Check the "
                      "terminal output for details)" % str(err),
                      "Error Reading SQD File")
            raise

    @cached_property
    def _get_sqd_fname(self):
        """Base name of ``sqd_file``, or ``'-'`` when unset."""
        if self.sqd_file:
            return os.path.basename(self.sqd_file)
        else:
            return '-'

    @cached_property
    def _get_stim_chs_array(self):
        """Integer array of stim channels, or ``None`` if unavailable/invalid.

        An empty ``stim_chs`` selects the default stim channels; otherwise the
        text is evaluated as the argument of ``numpy.r_[...]``.
        """
        if self.raw is None:
            return
        elif not self.stim_chs.strip():
            picks = _default_stim_chs(self.raw.info)
        else:
            try:
                # NOTE(security): eval() runs the text of the ``stim_chs``
                # trait with the numpy namespace in scope. Only acceptable
                # while that text comes from a trusted local user; consider a
                # restricted parser if it can come from elsewhere.
                picks = eval("r_[%s]" % self.stim_chs, vars(np))
                if picks.dtype.kind != 'i':
                    raise TypeError("Need array of int")
            except Exception:
                # Any failure means the specification is invalid; this is
                # reported through stim_chs_ok / stim_chs_comment.
                return None

        if self.stim_coding == '<':  # Big-endian
            return picks[::-1]
        else:
            return picks

    @cached_property
    def _get_stim_chs_comment(self):
        """Feedback text describing the current stim channel selection."""
        if self.raw is None:
            return ""
        elif not self.stim_chs_ok:
            return "Invalid!"
        elif not self.stim_chs.strip():
            return "Default:  The first 8 MISC channels"
        else:
            return "Ok:  %i channels" % len(self.stim_chs_array)

    @cached_property
    def _get_stim_chs_ok(self):
        """Whether a stim channel array could be determined."""
        return self.stim_chs_array is not None

    def clear_all(self):
        """Clear all specified input parameters."""
        self.markers.clear = True
        self.reset_traits(['sqd_file', 'hsp_file', 'fid_file', 'use_mrk'])

    def get_event_info(self):
        """Count events with current stim channel settings.

        Returns
        -------
        event_count : Counter | None
            Counter mapping event ID to number of occurrences. ``None`` when
            no data is loaded or the stim channel specification is invalid.
        """
        # stim_chs_array is None for an invalid specification; iterating it
        # below would raise TypeError.
        if self.misc_data is None or self.stim_chs_array is None:
            return
        idx = [self.misc_chs.index(ch) for ch in self.stim_chs_array]
        data = self.misc_data[idx]
        if self.stim_coding == 'channel':
            coding = 'channel'
        else:
            coding = 'binary'
        stim_ch = _make_stim_channel(data, self.stim_slope,
                                     self.stim_threshold, coding,
                                     self.stim_chs_array)
        events = _find_events(stim_ch, self.raw.first_samp, consecutive=True,
                              min_samples=3)
        return Counter(events[:, 2])

    def get_raw(self, preload=False):
        """Create a raw object based on the current model settings.

        Parameters
        ----------
        preload : bool
            Passed on to ``RawKIT``.

        Returns
        -------
        raw : RawKIT
            Raw object; digitizer points and the device-head transform are
            added when fiducials are available.

        Raises
        ------
        ValueError
            If ``can_save`` is False.
        """
        if not self.can_save:
            raise ValueError("Not all necessary parameters are set")

        # stim channels and coding
        if self.stim_coding == 'channel':
            stim_code = 'channel'
        elif self.stim_coding in ('<', '>'):
            # tuple membership: a substring test against '<>' would also
            # accept the empty string
            stim_code = 'binary'
        else:
            raise RuntimeError("stim_coding=%r" % self.stim_coding)

        logger.info("Creating raw with stim=%r, slope=%r, stim_code=%r, "
                    "stimthresh=%r", self.stim_chs_array, self.stim_slope,
                    stim_code, self.stim_threshold)
        raw = RawKIT(self.sqd_file, preload=preload, stim=self.stim_chs_array,
                     slope=self.stim_slope, stim_code=stim_code,
                     stimthresh=self.stim_threshold,
                     allow_unknown_format=self.allow_unknown_format)

        if np.any(self.fid):
            raw.info['dig'] = _make_dig_points(self.fid[0], self.fid[1],
                                               self.fid[2], self.elp,
                                               self.hsp)
            raw.info['dev_head_t'] = Transform('meg', 'head',
                                               self.dev_head_trans)
        return raw
Example #5
0
class Person(HasStrictTraits):
    """Defines some sample data to display in the TableEditor."""

    # Person's name (string trait).
    name = Str()
    # Person's age (integer trait).
    age = Int()
    # Gender, restricted to the two listed choices ('Male' is listed first).
    gender = Enum('Male', 'Female')
Example #6
0
class MultiCanvasManager(DataElement):
    """ Handles multiple ConstraintsPlotContainerManager.

    Offers the ability to use multiple constraint based enable canvases and to
    specify a strategy for how to choose where to add a plot. In mode 0 or if
    there is only 1 container, all plots are added to container 0. In mode 1,
    all plots are added to the last used plot, unless it contains more than the
    overflow_limit. In that case, the plot is added to the next container if
    any. In mode 2, the container used is the one immediately after the last
    one used.
    """
    #: All plot canvases
    container_managers = List(Instance(ConstraintsPlotContainerManager))

    #: Number of available canvases
    num_container_managers = Int(DEFAULT_NUM_CONTAINERS)

    #: Type of container manager: how to lay plots out
    container_layout_type = Enum(["horizontal", "vertical"])

    #: Mode to auto-select a container: 0 -> first row, 1 -> last row (with
    # overflow), 2 -> new row
    multi_container_mode = Enum(list(range(NUM_MODES)))

    #: Number of plots beyond which to overflow to next container (mode 1 only)
    overflow_limit = Int(DEFAULT_OVERFLOW_SIZE)

    #: Number of plots for each container
    container_content = Property(Dict, depends_on="container_managers")

    # Container padding parameters --------------------------------------------

    # Outer paddings:
    padding_top = Int(0)

    padding_bottom = Int(5)

    padding_left = Int(0)

    padding_right = Int(0)

    # Padding in between plots:
    layout_spacing = Int(60)

    layout_margins = Int(60)

    def __init__(self, **traits):
        super(MultiCanvasManager, self).__init__(**traits)

        # Create a trait pointing to each of the containers so it can be
        # displayed in the DFPlotter view:
        all_containers = {}
        for i, container in enumerate(self.container_managers):
            trait_name = CONTAINER_TRAIT_NAME.format(i)
            self.add_trait(trait_name,
                           Instance(ConstraintsPlotContainerManager))
            all_containers[trait_name] = container

        self.trait_set(**all_containers)
        self._initialize_all_managers()

    # Public interface --------------------------------------------------------

    def add_plot_to_container(self, desc, position=None, container=None):
        """ Insert the plot in the correct container.

        Parameters
        ----------
        desc : PlotDescriptor
            Descriptor of the plot to add.

        position : int or None
            Position of the plot. Leave as None to append.

        container : None or int or ConstraintsPlotContainerManager
            Container to add the plot to. An int is used as an index into
            ``container_managers``; None selects the container via
            :meth:`get_container_for_plot`.
        """
        key = self.build_container_key(desc)

        if container is None:
            container = self.get_container_for_plot(desc)
        elif isinstance(container, int):
            container = self.container_managers[container]

        container.add_plot(key, desc.plot, position=position)

        # ...and hide the plot if it is supposed to be hidden
        if not desc.visible:
            container.hide_plot(key)

    def remove_plot_from_container(self, desc, container=None):
        """ Remove the plot from corresponding container.

        Nothing happens if the plot's key is not in the container's plot map.

        Parameters
        ----------
        desc : PlotDescriptor
            Descriptor of the plot to remove.

        container : None or int or ConstraintsPlotContainerManager
            Container to remove the plot from. An int is used as an index into
            ``container_managers``; None selects the container via
            :meth:`get_container_for_plot`.
        """
        key = self.build_container_key(desc)

        if container is None:
            container = self.get_container_for_plot(desc)
        elif isinstance(container, int):
            container = self.container_managers[container]

        if key in container.plot_map:
            container.delete_plot(key, desc.plot)

    def get_container_for_plot(self, desc):
        """ Return the container a plot should be in.

        A negative ``desc.container_idx`` is replaced by an automatically
        selected index. An index beyond the available containers is replaced
        by the index of the last container. In both cases
        ``desc.container_idx`` is updated.
        """
        idx = desc.container_idx
        if idx < 0:
            idx = desc.container_idx = self._get_container_idx()

        if idx < len(self.container_managers):
            return self.container_managers[idx]
        else:
            # This could happen when serializing a project, and re-opening it
            # with different preferences, where the number of containers is
            # less.
            # The message placeholder used to be logged unformatted; identify
            # the plot by its container key.
            logger.info("Failed to find the requested container for %s: "
                        "sending the plot to the last container.",
                        self.build_container_key(desc))
            desc.container_idx = len(self.container_managers) - 1
            return self.container_managers[-1]

    @staticmethod
    def build_container_key(desc):
        """ Return the (plot_type, id) key identifying a plot in a container.
        """
        # Isolated in a method to allow changing the mapping.
        return desc.plot_type, desc.id

    # Private interface -------------------------------------------------------

    def _initialize_all_managers(self):
        """ Create the container inside each container manager to allow plots
        to be added
        """
        for manager in self.container_managers:
            manager.init()

    def _get_container_idx(self):
        """ Compute index of container to use based on mode and current state.
        """
        if self.multi_container_mode == 0 or len(self.container_managers) == 1:
            return 0

        content = self.container_content
        used_containers = [key for key, val in content.items() if val]
        if not used_containers:
            return 0

        container_idx = max(used_containers)
        if self.multi_container_mode == 1:
            if content[container_idx] >= self.overflow_limit:
                container_idx += 1
        elif self.multi_container_mode == 2:
            container_idx += 1
        else:
            msg = "Multi-container mode {} not supported."
            msg = msg.format(self.multi_container_mode)
            # No exception is being handled here, so logger.exception would
            # attach an empty traceback; log a plain error instead.
            logger.error(msg)
            raise NotImplementedError(msg)

        # Never point past the last available container.
        return min(container_idx, len(self.container_managers) - 1)

    # Traits listeners --------------------------------------------------------

    def _num_container_managers_changed(self):
        self.container_managers = self._container_managers_default()
        self._initialize_all_managers()

    # Traits property getters/setters -----------------------------------------

    def _get_container_content(self):
        # Map each container index to the number of plots it holds.
        return {i: len(container.plot_map)
                for i, container in enumerate(self.container_managers)}

    # Traits initialization methods -------------------------------------------

    def _container_managers_default(self):
        # One manager per requested canvas, all sharing the layout/padding
        # settings of this object.
        managers = [
            ConstraintsPlotContainerManager(
                layout_type=self.container_layout_type,
                padding_top=self.padding_top,
                padding_bottom=self.padding_bottom,
                padding_left=self.padding_left,
                padding_right=self.padding_right,
                layout_spacing=self.layout_spacing,
                layout_margins=self.layout_margins)
            for _ in range(self.num_container_managers)
        ]

        return managers
Example #7
0
        def make_gui_items(pname, param, first_call=False):
            """Build the list of GUI items from an AMI parameter dictionary.

            Parameters
            ----------
            pname : str
                Name of the parameter, or label of the parameter group.
            param : AMIParameter or mapping
                A single parameter, or a mapping of sub-parameter names to
                parameters/mappings. A ``'description'`` entry is used as the
                group description.
            first_call : bool
                When True, the resulting group is laid out horizontally.

            Returns
            -------
            gui_items : list
                Items/groups to display.
            new_traits : list of (name, trait) tuples
                Traits to be added for the displayed items.
            """

            gui_items = []
            new_traits = []
            if isinstance(param, AMIParameter):
                pusage = param.pusage
                if pusage in ('In', 'InOut'):
                    if param.ptype == 'Boolean':
                        new_traits.append((pname, Bool(param.pvalue)))
                        gui_items.append(
                            Item(pname, tooltip=param.pdescription))
                    else:
                        pformat = param.pformat
                        if pformat == 'Range':
                            new_traits.append((pname,
                                               Range(param.pmin, param.pmax,
                                                     param.pvalue)))
                            gui_items.append(
                                Item(pname, tooltip=param.pdescription))
                        elif pformat == 'List':
                            list_tips = param.plist_tip
                            default = param.pdefault
                            if list_tips:
                                # NOTE: an earlier attempt to prevent the
                                # Traits/UI machinery from sorting the tips
                                # alphabetically (by prefixing each tip with
                                # its index) did not work.
                                tmp_dict = dict(zip(list_tips, param.pvalue))
                                # dict.keys() is not subscriptable on
                                # Python 3; list(...) works on both 2 and 3.
                                val = list(tmp_dict)[0]
                                if default:
                                    for tip, tip_value in tmp_dict.items():
                                        if tip_value == default:
                                            val = tip
                                            break
                                new_traits.append((pname, Trait(val,
                                                                tmp_dict)))
                            else:
                                val = param.pvalue[0]
                                if default:
                                    val = default
                                new_traits.append(
                                    (pname, Enum([val] + param.pvalue)))
                            gui_items.append(
                                Item(pname, tooltip=param.pdescription))
                        else:  # Value
                            new_traits.append((pname, param.pvalue))
                            gui_items.append(
                                Item(pname,
                                     style='readonly',
                                     tooltip=param.pdescription))
            else:  # subparameter branch
                # sorted() instead of keys().sort(): the latter fails on
                # Python 3, where keys() returns a view.
                subparam_names = sorted(param.keys())
                sub_items = []
                group_desc = ''

                # Build GUI items for this branch.
                for subparam_name in subparam_names:
                    if subparam_name == 'description':
                        group_desc = param[subparam_name]
                    else:
                        tmp_items, tmp_traits = make_gui_items(
                            subparam_name, param[subparam_name])
                        sub_items.extend(tmp_items)
                        new_traits.extend(tmp_traits)

                # Put all top-level ungrouped parameters in a single VGroup.
                top_lvl_params = []
                sub_params = []
                for item in sub_items:
                    if isinstance(item, Item):
                        top_lvl_params.append(item)
                    else:
                        sub_params.append(item)
                sub_items = [Group(top_lvl_params)] + sub_params

                # Make the top-level group an HGroup; all others VGroups (default).
                group_kwargs = dict(label=pname, show_border=True)
                if first_call:
                    group_kwargs['orientation'] = 'horizontal'
                gui_items.append(
                    Group([Item(label=group_desc)] + sub_items,
                          **group_kwargs))

            return gui_items, new_traits
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for buttons.

    The ``value`` property is backed by the private ``_value`` attribute.
    Strings assigned to it are coerced to ``int`` or, failing that, to
    ``float``; strings that are not numeric literals are stored unchanged.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    # Value to set when the button is clicked
    value = Property()

    # Optional label for the button
    label = Str()

    # The name of the external object trait that the button label is synced to
    label_value = Str()

    # The name of the trait on the object that contains the list of possible
    # values.  If this is set, then the value, label, and label_value traits
    # are ignored; instead, they will be set from this list.  When this button
    # is clicked, the value set will be the one selected from the drop-down.
    values_trait = Either(None, Str)

    # (Optional) Image to display on the button
    image = Image

    # Extra padding to add to both the left and the right sides
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and the bottom sides
    height_padding = Range(0, 31, 5)

    # Presentation style
    style = Enum("button", "radio", "toolbar", "checkbox")

    # Orientation of the text relative to the image
    orientation = Enum("vertical", "horizontal")

    # The optional view to display when the button is clicked:
    view = AView

    # -------------------------------------------------------------------------
    #  Traits view definition:
    # -------------------------------------------------------------------------

    traits_view = View(["label", "value", "|[]"])

    def _get_value(self):
        """ Returns the stored button value. """
        return self._value

    def _set_value(self, value):
        """ Stores *value*, coercing numeric strings to ``int`` or ``float``.

        Non-string values, and strings that parse as neither an ``int`` nor
        a ``float``, are stored as given.
        """
        if isinstance(value, str):
            # Try the narrowest numeric type first.  Only ValueError is
            # caught so that unrelated exceptions (e.g. KeyboardInterrupt)
            # are no longer silently swallowed.
            for convert in (int, float):
                try:
                    converted = convert(value)
                except ValueError:
                    continue
                value = converted
                break
        self._value = value

    def __init__(self, **traits):
        # The backing attribute is set before the base initializer runs, so
        # it exists by the time any ``value`` passed in **traits is applied.
        self._value = 0
        super(ToolkitEditorFactory, self).__init__(**traits)
Example #9
0
class MRISubjectSource(HasPrivateTraits):
    """Find subjects in SUBJECTS_DIR and select one.

    Parameters
    ----------
    subjects_dir : directory
        SUBJECTS_DIR.
    subject : str
        Subject, corresponding to a folder in SUBJECTS_DIR.
    """

    refresh = Event(desc="Refresh the subject list based on the directory "
                    "structure of subjects_dir.")

    # User-controlled settings
    subjects_dir = Directory(exists=True)
    subjects = Property(List(Str), depends_on=['subjects_dir', 'refresh'])
    subject = Enum(values='subjects')
    use_high_res_head = Bool(True)

    # Derived, read-only information
    can_create_fsaverage = Property(Bool,
                                    depends_on=['subjects_dir', 'subjects'])
    subject_has_bem = Property(Bool,
                               depends_on=['subjects_dir', 'subject'],
                               desc="whether the subject has a file matching "
                               "the bem file name pattern")
    # NOTE(review): no 'mri_dir' trait or _get_bem_pattern is declared in
    # this class body -- confirm they are provided elsewhere.
    bem_pattern = Property(depends_on='mri_dir')

    @cached_property
    def _get_can_create_fsaverage(self):
        """Return True if subjects_dir exists and lacks an 'fsaverage'."""
        return (os.path.exists(self.subjects_dir) and
                'fsaverage' not in self.subjects)

    @cached_property
    def _get_mri_dir(self):
        """Return the selected subject's directory, or None if unset."""
        if not self.subject or not self.subjects_dir:
            return None
        return os.path.join(self.subjects_dir, self.subject)

    @cached_property
    def _get_subjects(self):
        """List the MRI subjects found in subjects_dir.

        Falls back to a single empty-string entry when the directory is
        unset, is not a directory, or contains no subjects.
        """
        root = self.subjects_dir
        if not (root and os.path.isdir(root)):
            return ['']

        found = [entry for entry in os.listdir(root)
                 if _is_mri_subject(entry, root)]
        return found or ['']

    @cached_property
    def _get_subject_has_bem(self):
        """Return whether the selected subject has a matching bem file."""
        if self.subject:
            return _mri_subject_has_bem(self.subject, self.subjects_dir)
        return False

    def create_fsaverage(self):
        """Create the fsaverage subject in subjects_dir and select it.

        Raises
        ------
        RuntimeError
            If no subjects directory is selected, or if the MNE root or the
            FreeSurfer home cannot be located.
        """
        if not self.subjects_dir:
            raise RuntimeError("No subjects directory is selected. Please "
                               "specify subjects_dir first.")

        mne_root = get_mne_root()
        if mne_root is None:
            raise RuntimeError("MNE contains files that are needed for "
                               "copying the fsaverage brain. Please install "
                               "MNE and try again.")

        fs_home = get_fs_home()
        if fs_home is None:
            raise RuntimeError("FreeSurfer contains files that are needed "
                               "for copying the fsaverage brain. Please "
                               "install FreeSurfer and try again.")

        create_default_subject(mne_root, fs_home,
                               subjects_dir=self.subjects_dir)
        # Fire the refresh event so ``subjects`` is recomputed before the
        # new subject is selected.
        self.refresh = True
        self.subject = 'fsaverage'
Example #10
0
class Circle(Component):
    """
    The circle moves with the mouse cursor but leaves a translucent version of
    itself in its original position until the mouse button is released.
    """

    # RGBA fill color of the circle.
    color = (0.6, 0.7, 1.0, 1.0)
    bgcolor = "none"

    # Pointers shown while idle and while being dragged.
    normal_pointer = Pointer("arrow")
    moving_pointer = Pointer("hand")

    # Mouse position seen by the most recent drag event.
    prev_x = Float
    prev_y = Float

    # Which shadow class to leave behind while dragging, and the shadow
    # component itself once a drag has started.
    shadow_type = Enum("light", "dashed")
    shadow = Instance(Component)

    resizable = ""

    def __init__(self, **traits):
        Component.__init__(self, **traits)
        self.pointer = self.normal_pointer

    def _draw_mainlayer(self, gc, view_bounds=None, mode="default"):
        """ Fills a circle inscribed in this component's bounds. """
        # Local import keeps this block self-contained; math.pi replaces the
        # truncated literal so the arc is a full, closed circle.
        from math import pi

        with gc:
            gc.set_fill_color(self.color)
            dx, dy = self.bounds
            x, y = self.position
            radius = min(dx / 2.0, dy / 2.0)
            gc.arc(x + dx / 2.0, y + dy / 2.0, radius, 0.0, 2 * pi)
            gc.fill_path()

    def normal_left_down(self, event):
        """ Starts a drag: grabs the mouse and drops a shadow in place. """
        self.event_state = "moving"
        self.pointer = self.moving_pointer
        event.window.set_mouse_owner(self, event.net_transform())

        # Create our shadow
        if self.shadow_type == "light":
            klass = LightCircle
        else:
            klass = DashedCircle
        self.shadow = klass(
            bounds=self.bounds,
            position=self.position,
            color=(1.0, 1.0, 1.0, 1.0),
        )
        # Index 0 puts the shadow first in the container's component list.
        self.container.insert(0, self.shadow)
        self.prev_x = event.x
        self.prev_y = event.y

    def moving_mouse_move(self, event):
        """ Moves the circle by the mouse delta since the last event. """
        self.position = [
            self.x + (event.x - self.prev_x),
            self.y + (event.y - self.prev_y),
        ]
        self.prev_x = event.x
        self.prev_y = event.y
        self.request_redraw()

    def moving_left_up(self, event):
        """ Ends the drag: releases the mouse and removes the shadow. """
        self.event_state = "normal"
        self.pointer = self.normal_pointer
        event.window.set_mouse_owner(None)
        event.window.redraw()
        # Remove our shadow
        self.container.remove(self.shadow)

    def moving_mouse_leave(self, event):
        """ Treats the mouse leaving mid-drag as a button release. """
        self.moving_left_up(event)
Example #11
0
class TextBoxOverlay(AbstractOverlay):
    """ Draws a box with text in it.
    """

    #### Configuration traits ##################################################

    # The text to display in the box.
    text = Str

    # The font to use for the text.
    font = KivaFont("modern 12")

    # The background color for the box (overrides AbstractOverlay).
    bgcolor = ColorTrait("transparent")

    # The alpha value to apply to **bgcolor**.  None leaves the color's own
    # alpha untouched; any number (including 0.0) replaces it.
    alpha = Trait(1.0, None, Float)

    # The color of the outside box.
    border_color = ColorTrait("dodgerblue")

    # The thickness of box border.
    border_size = Int(1)

    # Number of pixels of padding around the text within the box.
    padding = Int(5)

    # Alignment of the text in the box:
    #
    # * "ur": upper right
    # * "ul": upper left
    # * "ll": lower left
    # * "lr": lower right
    align = Enum("ur", "ul", "ll", "lr")

    # This allows subclasses to specify an alternate position for the root
    # of the text box.  Must be a sequence of length 2.
    alternate_position = Any

    #### Public 'AbstractOverlay' interface ####################################

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws the box overlaid on another component.

        Overrides AbstractOverlay.

        Parameters
        ----------
        component
            The component being overlaid; its ``x``, ``y``, ``x2``, ``y2``,
            ``width`` and ``height`` are used to position the box.
        gc
            The graphics context to draw into.
        view_bounds, mode
            Accepted for interface compatibility; not used here.
        """

        if not self.visible:
            return

        # draw the label on a transparent box. This allows us to draw
        # different shapes and put the text inside it without the label
        # filling a rectangle on top of it
        label = Label(text=self.text,
                      font=self.font,
                      bgcolor="transparent",
                      margin=5)
        width, height = label.get_width_height(gc)

        # ``align`` is a two-character code: vertical then horizontal.
        valign, halign = self.align

        if self.alternate_position:
            x, y = self.alternate_position
            if valign == "u":
                y += self.padding
            else:
                y -= self.padding + height

            if halign == "r":
                x += self.padding
            else:
                x -= self.padding + width
        else:
            if valign == "u":
                y = component.y2 - self.padding - height
            else:
                y = component.y + self.padding

            if halign == "r":
                x = component.x2 - self.padding - width
            else:
                x = component.x + self.padding

        # attempt to get the box entirely within the component
        if x + width > component.width:
            x = max(0, component.width - width)
        if y + height > component.height:
            y = max(0, component.height - height)
        elif y < 0:
            y = 0

        # apply the alpha channel.  Compare against None rather than relying
        # on truthiness so that alpha=0.0 (fully transparent) is honored.
        color = self.bgcolor_
        if self.bgcolor != "transparent" and self.alpha is not None:
            color = list(self.bgcolor_)
            if len(color) == 4:
                color[3] = self.alpha
            else:
                color += [self.alpha]

        with gc:
            gc.translate_ctm(x, y)

            gc.set_line_width(self.border_size)
            gc.set_stroke_color(self.border_color_)
            gc.set_fill_color(color)

            # draw a rounded rectangle; after the translate above the box
            # origin is (0, 0) in the current coordinate system
            x = y = 0
            end_radius = 8.0
            gc.begin_path()
            gc.move_to(x + end_radius, y)
            gc.arc_to(x + width, y, x + width, y + end_radius, end_radius)
            gc.arc_to(x + width, y + height, x + width - end_radius,
                      y + height, end_radius)
            gc.arc_to(x, y + height, x, y, end_radius)
            gc.arc_to(x, y, x + width + end_radius, y, end_radius)
            gc.draw_path()

            label.draw(gc)
Example #12
0
class MATS1DDamage(MATS1DEval):
    '''
    Scalar Damage Model.

    The state array holds two floats: index 0 is the history variable kappa
    from the last accepted step, index 1 is kappa of the current iteration.
    '''

    E = Float(
        1.,  # 34e+3,
        modified=True,
        label="E",
        desc="Young's Modulus",
        enter_set=True,
        auto_set=False)

    epsilon_0 = Float(
        1.,  # 59e-6,
        modified=True,
        label="eps_0",
        desc="Breaking Strain",
        enter_set=True,
        auto_set=False)

    epsilon_f = Float(
        1.,  # 191e-6,
        modified=True,
        label="eps_f",
        desc="Shape Factor",
        enter_set=True,
        auto_set=False)

    stiffness = Enum("secant", "algorithmic", modified=True)

    # This event can be used by the clients to trigger an action upon
    # the completed reconfiguration of the material model
    #
    changed = Event

    #--------------------------------------------------------------------------
    # View specification
    #--------------------------------------------------------------------------

    traits_view = View(Group(Group(Item('E'),
                                   Item('epsilon_0'),
                                   Item('epsilon_f'),
                                   label='Material parameters',
                                   show_border=True),
                             Group(
                                 Item('stiffness', style='custom'),
                                 Spring(resizable=True),
                                 label='Configuration parameters',
                                 show_border=True,
                             ),
                             layout='tabbed'),
                       resizable=True)

    #-------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-------------------------------------------------------------------------

    def get_state_array_size(self):
        '''
        Give back the number of floats to be saved
        '''
        return 2

    def new_cntl_var(self):
        '''Return a zeroed one-element control variable array.'''
        # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
        return np.zeros(1, np.float64)

    def new_resp_var(self):
        '''Return a zeroed one-element response variable array.'''
        return np.zeros(1, np.float64)

    #-------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng, d_eps, tn, tn1, eps_avg=None):
        '''
        Corrector predictor computation.
        @param sctx spatial context providing mats_state_array and
               update_state_on
        @param eps_app_eng input variable - engineering strain
        @param d_eps, tn, tn1 not used by this model
        @param eps_avg strain driving the damage evolution; defaults to
               eps_app_eng
        @return tuple (sigma, D_e_dam) - stress and material stiffness
        '''
        # Identity test: '== None' on a numpy array compares elementwise.
        if eps_avg is None:
            eps_avg = eps_app_eng

        D_el = np.array([self.E])

        if sctx.update_state_on:
            # Accept the last iterate as the new history value.
            sctx.mats_state_array[0] = sctx.mats_state_array[1]

        kappa_k, omega = self._get_state_variables(sctx, eps_avg)
        sctx.mats_state_array[1] = kappa_k

        if self.stiffness == "algorithmic":
            D_e_dam = np.array(
                [self._get_alg_stiffness(sctx, eps_app_eng, kappa_k, omega)])
        else:
            D_e_dam = np.array([(1 - omega) * D_el])

        # The stress always uses the secant (damaged elastic) stiffness.
        sigma = np.dot(np.array([(1 - omega) * D_el]), eps_app_eng)

        return sigma, D_e_dam

    #--------------------------------------------------------------------------
    # Subsidiary methods realizing configurable features
    #--------------------------------------------------------------------------

    def _get_state_variables(self, sctx, eps):
        '''Return (kappa, omega) for the strain eps.

        kappa is the larger of |eps| and the stored history value.
        '''
        kappa_n, _ = sctx.mats_state_array

        kappa_k = max(abs(eps), kappa_n)

        omega = self._get_omega(sctx, kappa_k)

        return kappa_k, omega

    def _get_omega(self, sctx, kappa):
        '''Return the damage value for kappa.

        Zero below epsilon_0, exponential softening from there on.
        '''
        epsilon_0 = self.epsilon_0
        epsilon_f = self.epsilon_f
        if kappa >= epsilon_0:
            return 1. - epsilon_0 / kappa * exp(
                -1 * (kappa - epsilon_0) / epsilon_f)
        else:
            return 0.

    def _get_alg_stiffness(self, sctx, eps_app_eng, e_max, omega):
        '''Return the algorithmic stiffness using d(omega)/d(kappa) at e_max.
        '''
        D_el = np.array([self.E])
        epsilon_0 = self.epsilon_0
        epsilon_f = self.epsilon_f
        # Both terms of the derivative share the same exponential factor.
        decay = exp(-(e_max - epsilon_0) / epsilon_f)
        dodk = (epsilon_0 / (e_max * e_max) * decay +
                epsilon_0 / e_max / epsilon_f * decay)
        D_alg = (1 - omega) * D_el - D_el * eps_app_eng * dodk
        return D_alg

    #--------------------------------------------------------------------------
    # Response trace evaluators
    #--------------------------------------------------------------------------

    def get_omega(self, sctx, eps_app_eng, eps_avg=None):
        '''Return the damage for eps_avg (defaults to eps_app_eng).'''
        if eps_avg is None:
            eps_avg = eps_app_eng
        return self._get_omega(sctx, eps_avg)

    # Declare and fill-in the rte_dict - it is used by the clients to
    # assemble all the available time-steppers.
    #
    rte_dict = Trait(Dict)

    def _rte_dict_default(self):
        return {
            'sig_app': self.get_sig_app,
            'eps_app': self.get_eps_app,
            'omega': self.get_omega
        }

    #-------------------------------------------------------------------------
    # List of response tracers to be constructed within the mats_explorer
    #-------------------------------------------------------------------------
    def _get_explorer_rtrace_list(self):
        '''Return the list of relevant tracers to be used in mats_explorer.
        '''
        return []

    def _get_explorer_config(self):
        '''Return the explorer configuration for this material model.

        Extends the base configuration with a model instance, one boundary
        condition, a time line and two response tracers.
        '''
        from ibvpy.api import TLine, RTDofGraph, BCDof
        ec = super(MATS1DDamage, self)._get_explorer_config()
        ec['mats_eval'] = MATS1DDamage(E=1.0, epsilon_0=1.0, epsilon_f=5)
        ec['bcond_list'] = [
            BCDof(var='u',
                  dof=0,
                  value=1.7,
                  time_function=lambda t: (1 + 0.1 * t) * sin(t))
        ]
        ec['tline'] = TLine(step=0.1, max=10)
        ec['rtrace_list'] = [
            RTDofGraph(name='strain - stress',
                       var_x='eps_app',
                       idx_x=0,
                       var_y='sig_app',
                       idx_y=0,
                       record_on='update'),
            RTDofGraph(name='time - damage',
                       var_x='time',
                       idx_x=0,
                       var_y='omega',
                       idx_y=0,
                       record_on='update')
        ]
        return ec
Example #13
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for instance editors.

    This class is purely declarative: it defines the configuration traits of
    an instance editor and a ``traits_view`` that exposes the ``label``,
    ``view`` and ``kind`` traits for editing.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: List of items describing the types of selectable or editable instances
    values = List(InstanceChoiceItem)

    #: Extended name of the context object trait containing the list of types of
    #: selectable or editable instances
    name = Str

    #: Is the current value of the object trait editable (vs. merely
    #: selectable)?
    editable = Bool(True)

    #: Should the object trait value be selectable from a list of objects (a
    #: value of True forces a selection list to be displayed, while a value of
    #: False displays a selection list only if at least one object in the list
    #: of possible object values is selectable):
    selectable = Bool(False)

    #: Should the editor support drag and drop of objects to set the trait value
    #: (a value of True forces the editor to allow drag and drop, while a value
    #: of False only supports drag and drop if at least one item in the list of
    #: possible objects supports drag and drop):
    droppable = Bool(False)

    #: Should factory-created objects be cached?
    cachable = Bool(True)

    #: Optional label for button
    label = Unicode

    #: Optional instance view to use
    view = AView

    #: Extended name of the context object trait containing the view, or name of
    #: the view, to use
    view_name = Str

    #: The ID to use with the view
    id = Str

    #: Kind of pop-up editor (live, modal, nonmodal, wizard)
    kind = AKind

    #: The orientation of the instance editor relative to the instance selector
    orientation = Enum("default", "horizontal", "vertical")

    #: The default adapter class used to create InstanceChoice compatible
    #: adapters for instance objects:
    adapter = Type(InstanceChoice, allow_none=False)

    # -------------------------------------------------------------------------
    #  Traits view definitions:
    # -------------------------------------------------------------------------

    # Two nested groups: the first holds the button label and view name, the
    # second holds the pop-up kind under the "Pop-up editor style" title.
    traits_view = View(
        [
            ["label{Button label}", "view{View name}", "|[]"],
            ["kind@", "|[Pop-up editor style]<>"],
        ]
    )
class DataRange1D(BaseDataRange):
    """ Represents a 1-D data range.
    """

    # The actual value of the lower bound of this range (overrides
    # AbstractDataRange). To set it, use **low_setting**.
    low = Property
    # The actual value of the upper bound of this range (overrides
    # AbstractDataRange). To set it, use **high_setting**.
    high = Property

    # Property for the lower bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The lower bound is automatically set at or below the minimum
    #   of the data.
    # * 'track': The lower bound tracks the upper bound by **tracking_amount**.
    # * CFloat: An explicit value for the lower bound
    low_setting = Property(Trait('auto', 'auto', 'track', CFloat))
    # Property for the upper bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The upper bound is automatically set at or above the maximum
    #   of the data.
    # * 'track': The upper bound tracks the lower bound by **tracking_amount**.
    # * CFloat: An explicit value for the upper bound
    high_setting = Property(Trait('auto', 'auto', 'track', CFloat))

    # Do "auto" bounds imply an exact fit to the data? If False,
    # they pad a little bit of margin on either side.
    tight_bounds = Bool(True)

    # A user supplied function returning the proper bounding interval.
    # bounds_func takes (data_low, data_high, margin, tight_bounds)
    # and returns (low, high)
    bounds_func = Callable

    # The amount of margin to place on either side of the data, expressed as
    # a percentage of the full data width
    margin = Float(0.05)

    # The minimum percentage difference between low and high.  That is,
    # (high-low) >= epsilon * low.
    # Used to be 1.0e-20 but chaco cannot plot at such a precision!
    epsilon = CFloat(1.0e-10)

    # When either **high** or **low** tracks the other, track by this amount.
    default_tracking_amount = CFloat(20.0)

    # The current tracking amount. This value changes with zooming.
    tracking_amount = default_tracking_amount

    # Default tracking state. This value is used when self.reset() is called.
    #
    # * 'auto': Both bounds reset to 'auto'.
    # * 'high_track': The high bound resets to 'track', and the low bound
    #   resets to 'auto'.
    # * 'low_track': The low bound resets to 'track', and the high bound
    #   resets to 'auto'.
    default_state = Enum('auto', 'high_track', 'low_track')

    # FIXME: this attribute is not used anywhere, is it safe to remove it?
    # Is this range dependent upon another range?
    fit_to_subset = Bool(False)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The "_setting" attributes correspond to what the user has "set"; the
    # "_value" attributes are the actual numerical values for the given
    # setting.

    # The user-specified low setting.
    _low_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the low setting.
    _low_value = CFloat(-inf)
    # The user-specified high setting.
    _high_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the high setting.
    _high_value = CFloat(inf)

    # A list of attributes to persist
    # _pickle_attribs = ("_low_setting", "_high_setting")

    #------------------------------------------------------------------------
    # AbstractRange interface
    #------------------------------------------------------------------------

    def clip_data(self, data):
        """ Returns a list of data values that are within the range.

        Implements AbstractDataRange.
        """
        return compress(self.mask_data(data), data)

    def mask_data(self, data):
        """ Returns a mask array, indicating whether values in the given array
        are inside the range.

        Implements AbstractDataRange.
        """
        return ((data.view(ndarray) >= self._low_value) &
                (data.view(ndarray) <= self._high_value))

    def bound_data(self, data):
        """ Returns a tuple of indices for the start and end of the first run
        of *data* that falls within the range.

        Returns (0, 0) if no value of *data* is within the range.

        Implements AbstractDataRange.
        """
        mask = self.mask_data(data)
        runs = arg_find_runs(mask, "flat")
        # Since runs of "0" are also considered runs, we have to cycle through
        # until we find the first run of "1"s.
        for run in runs:
            if mask[run[0]] == 1:
                # arg_find_runs returns 1 past the end
                return run[0], run[1] - 1
        return (0, 0)

    def set_bounds(self, low, high):
        """ Sets all the bounds of the range simultaneously.

        The **updated** event is fired at most once, after both settings
        have been applied.

        Implements AbstractDataRange.
        """
        if low == 'track':
            # Set the high setting first
            result_high = self._do_set_high_setting(high, fire_event=False)
            result_low = self._do_set_low_setting(low, fire_event=False)
            result = result_low or result_high
        else:
            # Either set low first or order doesn't matter
            result_low = self._do_set_low_setting(low, fire_event=False)
            result_high = self._do_set_high_setting(high, fire_event=False)
            result = result_high or result_low
        if result:
            self.updated = result

    def scale_tracking_amount(self, multiplier):
        """ Sets the **tracking_amount** to a new value, scaled by *multiplier*.
        """
        self.tracking_amount = self.tracking_amount * multiplier
        self._do_track()

    def set_tracking_amount(self, amount):
        """ Sets the **tracking_amount** to a new value, *amount*.
        """
        self.tracking_amount = amount
        self._do_track()

    def set_default_tracking_amount(self, amount):
        """ Sets the **default_tracking_amount** to a new value, *amount*.
        """
        self.default_tracking_amount = amount

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def reset(self):
        """ Resets the bounds of this range, based on **default_state**.
        """
        # need to maintain 'track' setting
        if self.default_state == 'auto':
            self._high_setting = 'auto'
            self._low_setting = 'auto'
        elif self.default_state == 'low_track':
            self._high_setting = 'auto'
            self._low_setting = 'track'
        elif self.default_state == 'high_track':
            self._high_setting = 'track'
            self._low_setting = 'auto'
        self._refresh_bounds()
        self.tracking_amount = self.default_tracking_amount

    def refresh(self):
        """ If any of the bounds is 'auto' or 'track', this method refreshes
        the actual low and high values from the data sources.

        If the user has hard-coded both bounds, this method does nothing.
        """
        settings = (self._low_setting, self._high_setting)
        if 'auto' in settings or 'track' in settings:
            self._refresh_bounds()

    #------------------------------------------------------------------------
    # Private methods (getters and setters)
    #------------------------------------------------------------------------

    def _get_low(self):
        return float(self._low_value)

    def _set_low(self, val):
        return self._set_low_setting(val)

    def _get_low_setting(self):
        return self._low_setting

    def _do_set_low_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._low_setting != val:

            # Save the new setting.
            self._low_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = min(
                        [source.get_bounds()[0] for source in self.sources])
                else:
                    val = -inf
            elif val == 'track':
                if len(self.sources) > 0 or self._high_setting != 'auto':
                    val = self._high_value - self.tracking_amount
                else:
                    val = -inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._low_value != val:
                self._low_value = val
                # Keep a tracking high bound in step with the new low bound.
                if self._high_setting == 'track':
                    self._high_value = val + self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_low_setting(self, val):
        self._do_set_low_setting(val, True)

    def _get_high(self):
        return float(self._high_value)

    def _set_high(self, val):
        return self._set_high_setting(val)

    def _get_high_setting(self):
        return self._high_setting

    def _do_set_high_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._high_setting != val:

            # Save the new setting.
            self._high_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = max(
                        [source.get_bounds()[1] for source in self.sources])
                else:
                    val = inf
            elif val == 'track':
                if len(self.sources) > 0 or self._low_setting != 'auto':
                    val = self._low_value + self.tracking_amount
                else:
                    val = inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._high_value != val:
                self._high_value = val
                # Keep a tracking low bound in step with the new high bound.
                if self._low_setting == 'track':
                    self._low_value = val - self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_high_setting(self, val):
        self._do_set_high_setting(val, True)

    def _refresh_bounds(self):
        """ Recomputes the numerical bounds from the current settings and the
        bounds of all non-empty sources, firing **updated** if they change.
        """
        null_bounds = False
        if len(self.sources) == 0:
            null_bounds = True
        else:
            bounds_list = [source.get_bounds() for source in self.sources
                           if source.get_size() > 0]

            if len(bounds_list) == 0:
                null_bounds = True

        if null_bounds:
            # If we have no sources and our settings are "auto", then reset our
            # bounds to infinity; otherwise, set the _value to the corresponding
            # setting.
            if (self._low_setting in ("auto", "track")):
                self._low_value = -inf
            else:
                self._low_value = self._low_setting
            if (self._high_setting in ("auto", "track")):
                self._high_value = inf
            else:
                self._high_value = self._high_setting
            return
        else:
            mins, maxes = list(zip(*bounds_list))

            low_start, high_start = \
                     calc_bounds(self._low_setting, self._high_setting,
                                 mins, maxes, self.epsilon,
                                 self.tight_bounds, margin=self.margin,
                                 track_amount=self.tracking_amount,
                                 bounds_func=self.bounds_func)

        if (self._low_value != low_start) or (self._high_value != high_start):
            self._low_value = low_start
            self._high_value = high_start
            self.updated = (self._low_value, self._high_value)
        return

    def _do_track(self):
        """ Re-applies **tracking_amount** to whichever bound is tracking and
        fires **updated** if that bound moved.
        """
        changed = False
        if self._low_setting == 'track':
            new_value = self._high_value - self.tracking_amount
            if self._low_value != new_value:
                self._low_value = new_value
                changed = True
        elif self._high_setting == 'track':
            new_value = self._low_value + self.tracking_amount
            if self._high_value != new_value:
                self._high_value = new_value
                changed = True
        if changed:
            self.updated = (self._low_value, self._high_value)

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _sources_items_changed(self, event):
        self.refresh()
        for source in event.removed:
            source.on_trait_change(self.refresh, "data_changed", remove=True)
        for source in event.added:
            source.on_trait_change(self.refresh, "data_changed")

    def _sources_changed(self, old, new):
        self.refresh()
        # Either argument may be None (e.g. when there is no previous list);
        # treat that as "no sources" instead of failing on iteration.
        for source in old or ():
            source.on_trait_change(self.refresh, "data_changed", remove=True)
        for source in new or ():
            source.on_trait_change(self.refresh, "data_changed")

    #------------------------------------------------------------------------
    # Serialization interface
    #------------------------------------------------------------------------

    def _post_load(self):
        # There are no previous sources to unhook: pass an empty list, not
        # None, which _sources_changed would otherwise try to iterate.
        self._sources_changed([], self.sources)
Example #15
0
class WorkflowItem(HasStrictTraits):
    """
    One step of a Workflow: an operation plus the views attached to it.

    Notes
    -----
    Instances of this class get serialized, so every trait that must not be
    saved has to be flagged ``transient``.
    """

    # id of the wrapped operation (delegated)
    friendly_id = DelegatesTo('operation')

    # name of the wrapped operation (delegated)
    name = DelegatesTo('operation')

    # the operation wrapped by this item
    operation = Instance(IOperation, copy="ref")

    # vertical notebook view: may this page be deleted?
    deletable = Bool(True)

    # Handler for the wrapped operation.  It comes from the operation
    # plugin and decides which operation traits show up in the UI, plus any
    # special handling of them.  The handler keeps no state of its own, so
    # it can be created and thrown away whenever needed.
    operation_handler = Property(depends_on='operation',
                                 trait=Instance(Handler),
                                 transient=True)

    operation_traits = View(
        Item('operation_handler', style='custom', show_label=False))

    # the Experiment obtained by applying *operation* to the ``result`` of
    # the previous_wi WorkflowItem
    result = Instance(Experiment, transient=True)

    # Channels, conditions, metadata and statistics taken from result.
    # Normally these would be Properties (computed dynamically), but that
    # is hard to do with the multiprocess model.

    channels = List(Str, status=True)
    conditions = Dict(Str, pd.Series, status=True)
    metadata = Dict(Str, Any, status=True)
    statistics = Dict(Tuple(Str, Str), pd.Series, status=True)

    # the IViews attached to this operation
    views = List(IView, copy="ref")

    # the view that is currently selected
    current_view = Instance(IView, copy="ref")

    # handler of the currently selected view
    current_view_handler = Property(depends_on='current_view',
                                    trait=Instance(Handler),
                                    transient=True)

    current_view_traits = View(
        Item('current_view_handler', style='custom', show_label=False))

    # view showing the plot parameters
    plot_params_traits = View(
        Item('current_view_handler',
             editor=InstanceEditor(view='plot_params_traits'),
             style='custom',
             show_label=False))

    # view showing the current plot
    current_plot_view = View(
        Item('current_view_handler',
             editor=InstanceEditor(view='current_plot_view'),
             style='custom',
             show_label=False))

    # default view of this workflow item
    default_view = Instance(IView, copy="ref")

    # the WorkflowItem that precedes this one in the workflow
    previous_wi = Instance('WorkflowItem', transient=True)

    # the WorkflowItem that follows this one in the workflow
    next_wi = Instance('WorkflowItem', transient=True)

    # Validity state of this item.
    # MAGIC: the first value listed is the default
    status = Enum("invalid",
                  "estimating",
                  "applying",
                  "valid",
                  "loading",
                  status=True)

    # error and warning reporting
    op_error = Str(status=True)
    op_error_trait = Str(status=True)
    op_warning = Str(status=True)
    op_warning_trait = Str(status=True)
    estimate_error = Str(status=True)
    estimate_warning = Str(status=True)
    view_error = Str(status=True)
    view_error_trait = Str(status=True)
    view_warning = Str(status=True)
    view_warning_trait = Str(status=True)

    # the central event that kicks off the WorkflowItem update logic
    changed = Event

    # icon for the vertical notebook view.  Qt specific, sadly.
    icon = Property(depends_on='status', transient=True)

    # synchronization primitive guarding updates of the wi traits
    lock = Instance(threading.RLock, (), transient=True)

    # synchronization primitives used for plotting
    matplotlib_events = Any(transient=True)
    plot_lock = Any(transient=True)

    # events used to count how often apply() and plot() get called
    apply_called = Event
    plot_called = Event

    @cached_property
    def _get_icon(self):
        """Pick the ``QStyle.SP_*`` constant that matches ``status``."""
        state = self.status
        if state == "valid":
            return QtGui.QStyle.SP_DialogApplyButton  # @UndefinedVariable
        if state in ("estimating", "applying"):
            return QtGui.QStyle.SP_BrowserReload  # @UndefinedVariable
        # every remaining status (e.g. "invalid") gets the cancel icon
        return QtGui.QStyle.SP_DialogCancelButton  # @UndefinedVariable

    @cached_property
    def _get_operation_handler(self):
        """Build the handler for ``operation`` with this item as context."""
        wrapped = self.operation
        return wrapped.handler_factory(model=wrapped, context=self)

    @cached_property
    def _get_current_view_handler(self):
        """Build the handler for ``current_view``; None if no view is set."""
        if not self.current_view:
            return None
        return self.current_view.handler_factory(model=self.current_view,
                                                 context=self)

    def __str__(self):
        """Return ``<ClassName: OperationClassName>``."""
        own_name = self.__class__.__name__
        op_name = self.operation.__class__.__name__
        return "<{}: {}>".format(own_name, op_name)

    def __repr__(self):
        """Return the same ``<ClassName: OperationClassName>`` text."""
        own_name = self.__class__.__name__
        op_name = self.operation.__class__.__name__
        return "<{}: {}>".format(own_name, op_name)
"""
 Defines the class that describes the information on the inputs and
 outputs of an object in the pipeline.
"""
# Author: Prabhu Ramachandran <*****@*****.**>
# Copyright (c) 2008-2016, Prabhu Ramachandran Enthought, Inc.
# License: BSD Style.

# Enthought library imports.
from traits.api import HasTraits, Enum, List

# Dataset kinds that are supported; 'none' is listed first.
DataSet = Enum(
    'none',
    'any',
    'image_data',
    'rectilinear_grid',
    'poly_data',
    'structured_grid',
    'unstructured_grid',
)

# Kind of attribute; 'any' is listed first.
AttributeType = Enum(
    'any',
    'cell',
    'point',
    'none',
)

# The attribute itself; 'any' is listed first.
Attribute = Enum(
    'any',
    'none',
    'scalars',
    'vectors',
    'tensors',
)


################################################################################
# Utility functions.
################################################################################
def get_tvtk_dataset_name(dataset):
    """Given a TVTK dataset `dataset` return the string dataset type of
    the dataset.
    """
    result = 'none'
    if hasattr(dataset, 'is_a'):
class Calib_Params(HasTraits):
    """Traits model and tabbed TraitsUI view of the calibration parameters.

    All values are read from the parameter files under ``par_path`` by
    :meth:`_reload`, which ``__init__`` calls.  The ``GroupN`` attributes
    describe the tabs of ``Calib_Params_View``.
    """

    # general and unused variables
    pair_enable_flag = Bool(True)
    n_img = Int()
    # Start out as empty lists; _reload assigns the values read from the
    # parameter files.
    img_name = []
    img_cal = []
    hp_flag = Bool()
    allCam_flag = Bool()
    mmp_n1 = Float()
    mmp_n2 = Float()
    mmp_n3 = Float()
    mmp_d = Float()

    # images data
    cam_1 = Str("", label='Calibration picture camera 1')
    cam_2 = Str("", label='Calibration picture camera 2')
    cam_3 = Str("", label='Calibration picture camera 3')
    cam_4 = Str("", label='Calibration picture camera 4')
    ori_cam_1 = Str("", label='Orientation data picture camera 1')
    ori_cam_2 = Str("", label='Orientation data picture camera 2')
    ori_cam_3 = Str("", label='Orientation data picture camera 3')
    ori_cam_4 = Str("", label='Orientation data picture camera 4')

    fixp_name = Str("", label='File of Coordinates on plate')
    tiff_head = Bool(True, label='TIFF-Header')
    pair_head = Bool(True, label='Include pairs')
    chfield = Enum("Frame", "Field odd", "Field even")

    Group1_1 = Group(Item(name='cam_1'),
                     Item(name='cam_2'),
                     Item(name='cam_3'),
                     Item(name='cam_4'),
                     label='Calibration pictures',
                     show_border=True)
    Group1_2 = Group(Item(name='ori_cam_1'),
                     Item(name='ori_cam_2'),
                     Item(name='ori_cam_3'),
                     Item(name='ori_cam_4'),
                     label='Orientation data',
                     show_border=True)
    Group1_3 = Group(Item(name='fixp_name'),
                     Group(Item(name='tiff_head'),
                           Item(name='pair_head',
                                enabled_when='pair_enable_flag'),
                           Item(name='chfield',
                                show_label=False,
                                style='custom'),
                           orientation='vertical',
                           columns=3),
                     orientation='vertical')

    # Group 1 is the group of General parameters
    # number of cameras, use only quadruplets or also triplets/pairs?
    # names of the test images, calibration files

    Group1 = Group(Group1_1,
                   Group1_2,
                   Group1_3,
                   orientation='vertical',
                   label='Images Data')

    # calibration data detection
    # NOTE(review): the '' defaults below are not numbers; every one of
    # these traits is overwritten by _reload() during __init__.

    h_image_size = Int('', label='Image size horizontal')
    v_image_size = Int('', label='Image size vertical')
    h_pixel_size = Float('', label='Pixel size horizontal')
    v_pixel_size = Float('', label='Pixel size vertical')

    grey_value_treshold_1 = Int('', label='First Image')
    grey_value_treshold_2 = Int('', label='Second Image')
    grey_value_treshold_3 = Int('', label='Third Image')
    grey_value_treshold_4 = Int('', label='Forth Image')
    tolerable_discontinuity = Int('', label='Tolerable discontinuity')
    min_npix = Int('', label='min npix')
    min_npix_x = Int('', label='min npix in x')
    min_npix_y = Int('', label='min npix in y')
    max_npix = Int('', label='max npix')
    max_npix_x = Int('', label='max npix in x')
    max_npix_y = Int('', label='max npix in y')
    sum_of_grey = Int('', label='Sum of greyvalue')
    size_of_crosses = Int('', label='Size of crosses')

    Group2_1 = Group(Item(name='h_image_size'),
                     Item(name='v_image_size'),
                     Item(name='h_pixel_size'),
                     Item(name='v_pixel_size'),
                     label='Image properties',
                     show_border=True,
                     orientation='horizontal')

    # No trailing comma after the closing parenthesis: it would turn this
    # attribute into a 1-tuple instead of a Group.
    Group2_2 = Group(Item(name='grey_value_treshold_1'),
                     Item(name='grey_value_treshold_2'),
                     Item(name='grey_value_treshold_3'),
                     Item(name='grey_value_treshold_4'),
                     orientation='horizontal',
                     label='Grayvalue threshold',
                     show_border=True)

    Group2_3 = Group(Group(Item(name='min_npix'),
                           Item(name='min_npix_x'),
                           Item(name='min_npix_y'),
                           orientation='vertical'),
                     Group(Item(name='max_npix'),
                           Item(name='max_npix_x'),
                           Item(name='max_npix_y'),
                           orientation='vertical'),
                     Group(Item(name='tolerable_discontinuity'),
                           Item(name='sum_of_grey'),
                           Item(name='size_of_crosses'),
                           orientation='vertical'),
                     orientation='horizontal')

    Group2 = Group(Group2_1,
                   Group2_2,
                   Group2_3,
                   orientation='vertical',
                   label='Calibration Data Detection')

    # manual pre-orientation
    img_1_p1 = Int('', label='P1')
    img_1_p2 = Int('', label='P2')
    img_1_p3 = Int('', label='P3')
    img_1_p4 = Int('', label='P4')
    img_2_p1 = Int('', label='P1')
    img_2_p2 = Int('', label='P2')
    img_2_p3 = Int('', label='P3')
    img_2_p4 = Int('', label='P4')
    img_3_p1 = Int('', label='P1')
    img_3_p2 = Int('', label='P2')
    img_3_p3 = Int('', label='P3')
    img_3_p4 = Int('', label='P4')
    img_4_p1 = Int('', label='P1')
    img_4_p2 = Int('', label='P2')
    img_4_p3 = Int('', label='P3')
    img_4_p4 = Int('', label='P4')

    Group3_1 = Group(Item(name='img_1_p1'),
                     Item(name='img_1_p2'),
                     Item(name='img_1_p3'),
                     Item(name='img_1_p4'),
                     orientation='horizontal',
                     label='Image 1',
                     show_border=True)
    Group3_2 = Group(Item(name='img_2_p1'),
                     Item(name='img_2_p2'),
                     Item(name='img_2_p3'),
                     Item(name='img_2_p4'),
                     orientation='horizontal',
                     label='Image 2',
                     show_border=True)
    Group3_3 = Group(Item(name='img_3_p1'),
                     Item(name='img_3_p2'),
                     Item(name='img_3_p3'),
                     Item(name='img_3_p4'),
                     orientation='horizontal',
                     label='Image 3',
                     show_border=True)
    Group3_4 = Group(Item(name='img_4_p1'),
                     Item(name='img_4_p2'),
                     Item(name='img_4_p3'),
                     Item(name='img_4_p4'),
                     orientation='horizontal',
                     label='Image 4',
                     show_border=True)
    Group3 = Group(Group3_1,
                   Group3_2,
                   Group3_3,
                   Group3_4,
                   show_border=True,
                   label='Manual pre-orientation')

    # calibration orientation parameters

    Examine_Flag = Bool('', label='Calibrate with different Z')
    Combine_Flag = Bool('', label='Combine preprocessed planes')

    point_number_of_orientation = Int('', label='Point number of orientation')
    principle_distance = Bool(False, label='Princple distance')
    xp = Bool(False, label='xp')
    yp = Bool(False, label='yp')
    k1 = Bool(False, label='K1')
    k2 = Bool(False, label='K2')
    k3 = Bool(False, label='K3')
    p1 = Bool(False, label='P1')
    p2 = Bool(False, label='P2')
    scx = Bool(False, label='scx')
    she = Bool(False, label='she')
    interf = Bool(False, label='interfaces check box are available')

    Group4_0 = Group(Item(name='Examine_Flag'),
                     Item(name='Combine_Flag'),
                     show_border=True)

    Group4_1 = Group(Item(name='principle_distance'),
                     Item(name='xp'),
                     Item(name='yp'),
                     orientation='vertical',
                     columns=3)
    Group4_2 = Group(Item(name='k1'),
                     Item(name='k2'),
                     Item(name='k3'),
                     Item(name='p1'),
                     Item(name='p2'),
                     orientation='vertical',
                     columns=5,
                     label='Lens distortion(Brown)',
                     show_border=True)
    Group4_3 = Group(Item(name='scx'),
                     Item(name='she'),
                     orientation='vertical',
                     columns=2,
                     label='Affin transformation',
                     show_border=True)
    Group4_4 = Group(Item(name='interf'))

    Group4 = Group(Group(Group4_0,
                         Item(name='point_number_of_orientation'),
                         Group4_1,
                         Group4_2,
                         Group4_3,
                         Group4_4,
                         label=' Orientation Parameters ',
                         show_border=True),
                   orientation='horizontal',
                   show_border=True,
                   label='Calibration Orientation Param.')

    # dumbbell parameters (example file content)
    # 5  eps (mm)
    # 46.5 dumbbell scale
    # 0.005 gradient descent factor
    # 1 weight for dumbbell penalty
    # 2 step size through sequence
    # 500 num iterations per click

    dumbbell_eps = Float('', label='dumbbell epsilon')
    dumbbell_scale = Float('', label='dumbbell scale')
    dumbbell_gradient_descent = Float('',
                                      label='dumbbell gradient descent factor')
    dumbbell_penalty_weight = Float('', label='weight for dumbbell penalty')
    dumbbell_step = Int('', label='step size through sequence')
    dumbbell_niter = Int('', label='number of iterations per click')

    Group5 = HGroup(VGroup(Item(name='dumbbell_eps'),
                           Item(name='dumbbell_scale'),
                           Item(name='dumbbell_gradient_descent'),
                           Item(name='dumbbell_penalty_weight'),
                           Item(name='dumbbell_step'),
                           Item(name='dumbbell_niter')),
                    spring,
                    label='Dumbbell calibration parameters',
                    show_border=True)

    # shaking parameters (example file content)
    # 10000 - first frame
    # 10004 - last frame
    # 10 - max num points used per frame
    # 5 - max number of frames to track

    shaking_first_frame = Int('', label='shaking first frame')
    shaking_last_frame = Int('', label='shaking last frame')
    shaking_max_num_points = Int('', label='shaking max num points')
    shaking_max_num_frames = Int('', label='shaking max num frames')

    Group6 = HGroup(VGroup(Item(name='shaking_first_frame', ),
                           Item(name='shaking_last_frame'),
                           Item(name='shaking_max_num_points'),
                           Item(name='shaking_max_num_frames')),
                    spring,
                    label='Shaking calibration parameters',
                    show_border=True)

    Calib_Params_View = View(Tabbed(Group1, Group2, Group3, Group4, Group5,
                                    Group6),
                             buttons=['Undo', 'OK', 'Cancel'],
                             handler=CalHandler(),
                             title='Calibration Parameters')

    def _reload(self):
        """Read all parameter files under ``self.par_path`` into the traits.

        Each ``par.*Params`` reader is constructed, ``read()`` and its
        attributes are copied onto the matching traits of this object.
        Flag values are converted with the builtin ``bool``; the previous
        ``n.bool`` was presumably NumPy's alias of the builtin, which was
        deprecated in NumPy 1.20 and removed in 1.24.
        """
        # --- ptv parameters -------------------------------------------
        ptvParams = par.PtvParams(path=self.par_path)
        ptvParams.read()
        n_img = ptvParams.n_img

        # picture size parameters
        self.h_image_size = ptvParams.imx
        self.v_image_size = ptvParams.imy
        self.h_pixel_size = ptvParams.pix_x
        self.v_pixel_size = ptvParams.pix_y
        self.img_cal = ptvParams.img_cal
        # 'pair_head' is only enabled in the view while this flag is True,
        # i.e. while allCam_flag is false.
        self.pair_enable_flag = not ptvParams.allCam_flag

        # parameters that are not shown in the view
        self.n_img = n_img
        self.img_name = ptvParams.img_name
        self.hp_flag = bool(ptvParams.hp_flag)
        self.allCam_flag = bool(ptvParams.allCam_flag)
        self.mmp_n1 = ptvParams.mmp_n1
        self.mmp_n2 = ptvParams.mmp_n2
        self.mmp_n3 = ptvParams.mmp_n3
        self.mmp_d = ptvParams.mmp_d

        # --- calibration / orientation file names ---------------------
        calOriParams = par.CalOriParams(n_img, path=self.par_path)
        calOriParams.read()
        img_cal_name = calOriParams.img_cal_name
        img_ori = calOriParams.img_ori

        # The view has exactly four camera slots.
        for cam in range(4):
            setattr(self, 'cam_%d' % (cam + 1), img_cal_name[cam])
        for cam in range(4):
            setattr(self, 'ori_cam_%d' % (cam + 1), img_ori[cam])
        self.tiff_head = bool(calOriParams.tiff_flag)
        self.pair_head = bool(calOriParams.pair_flag)
        self.fixp_name = calOriParams.fixp_name

        # Numeric chfield code -> Enum label; anything other than 0 or 1
        # maps to "Field even".
        chfield = calOriParams.chfield
        if chfield == 0:
            self.chfield = "Frame"
        elif chfield == 1:
            self.chfield = "Field odd"
        else:
            self.chfield = "Field even"

        # --- detect plate parameters ----------------------------------
        detectPlateParams = par.DetectPlateParams(path=self.par_path)
        detectPlateParams.read()

        self.grey_value_treshold_1 = detectPlateParams.gvth_1
        self.grey_value_treshold_2 = detectPlateParams.gvth_2
        self.grey_value_treshold_3 = detectPlateParams.gvth_3
        self.grey_value_treshold_4 = detectPlateParams.gvth_4
        self.tolerable_discontinuity = detectPlateParams.tol_dis
        self.min_npix = detectPlateParams.min_npix
        self.min_npix_x = detectPlateParams.min_npix_x
        self.min_npix_y = detectPlateParams.min_npix_y
        self.max_npix = detectPlateParams.max_npix
        self.max_npix_x = detectPlateParams.max_npix_x
        self.max_npix_y = detectPlateParams.max_npix_y
        self.sum_of_grey = detectPlateParams.sum_grey
        self.size_of_crosses = detectPlateParams.size_cross

        # --- manual orientation parameters ----------------------------
        manOriParams = par.ManOriParams(n_img, 4, path=self.par_path)
        manOriParams.read()
        nr = manOriParams.nr

        # nr[image][point] -> img_<image+1>_p<point+1>
        for image in range(4):
            for point in range(4):
                setattr(self,
                        'img_%d_p%d' % (image + 1, point + 1),
                        nr[image][point])

        # --- examine parameters ---------------------------------------
        examineParams = par.ExamineParams(path=self.par_path)
        examineParams.read()
        self.Examine_Flag = examineParams.Examine_Flag
        self.Combine_Flag = examineParams.Combine_Flag

        # --- orientation parameters -----------------------------------
        orientParams = par.OrientParams(path=self.par_path)
        orientParams.read()

        self.point_number_of_orientation = orientParams.pnfo
        self.principle_distance = bool(orientParams.prin_dis)
        for flag in ('xp', 'yp', 'k1', 'k2', 'k3', 'p1', 'p2', 'scx', 'she',
                     'interf'):
            setattr(self, flag, bool(getattr(orientParams, flag)))

        # --- dumbbell parameters --------------------------------------
        dumbbellParams = par.DumbbellParams(path=self.par_path)
        dumbbellParams.read()
        for field in ('dumbbell_eps', 'dumbbell_scale',
                      'dumbbell_gradient_descent', 'dumbbell_penalty_weight',
                      'dumbbell_step', 'dumbbell_niter'):
            setattr(self, field, getattr(dumbbellParams, field))

        # --- shaking parameters ---------------------------------------
        shakingParams = par.ShakingParams(path=self.par_path)
        shakingParams.read()
        for field in ('shaking_first_frame', 'shaking_last_frame',
                      'shaking_max_num_points', 'shaking_max_num_frames'):
            setattr(self, field, getattr(shakingParams, field))

    def __init__(self, par_path):
        """Remember ``par_path`` and load every parameter from it.

        Parameters
        ----------
        par_path :
            Passed as ``path=`` to each ``par.*Params`` reader.
        """
        # Run the HasTraits initialisation before touching any trait.
        super(Calib_Params, self).__init__()
        self.par_path = par_path
        self._reload()
Example #18
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for tree editors.
    """
    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # TreeNode objects that are supported
    nodes = List(TreeNode)

    # Maps TreeNode tuples to MultiTreeNodes
    multi_nodes = Dict

    # Column header labels, if there are any
    column_headers = List(Str)

    # Can the individual nodes be edited?
    editable = Bool(True)

    # Selection mode
    selection_mode = Enum('single', 'extended')

    # Is the editor shared between trees?
    shared_editor = Bool(False)

    # Reference to a shared object editor
    editor = Instance(EditorFactory)

    # Graphical theme of the DockWindow
    # FIXME: Implemented only in wx backend.
    dock_theme = Instance(DockWindowTheme)

    # Should tree nodes show icons?
    show_icons = Bool(True)

    # Should the root node of the tree be hidden?
    hide_root = Bool(False)

    # How the tree and the editor are laid out
    orientation = Orientation

    # How many tree levels (counted down from the root) are opened
    # automatically
    auto_open = Int

    # Size of the icons of the tree nodes
    # FIXME: Document as unimplemented or wx specific.
    icon_size = IconSize

    # Called when a node gets selected
    on_select = Any

    # Called when a node gets clicked
    on_click = Any

    # Called when a node gets double-clicked
    on_dclick = Any

    # Called when a node gets activated
    on_activated = Any

    # Called when the mouse hovers over a node
    on_hover = Any

    # Optional extended trait name of the trait that is kept in sync with
    # the current selection of the editor:
    selected = Str

    # Optional extended trait name of the trait that is assigned a node
    # object when a tree node is activated, either by a double-click or by
    # pressing the Enter key while the node has focus (Note: to receive
    # repeated activated events on the same node, define the trait as an
    # Event):
    activated = Str

    # Optional extended trait name of the trait that is assigned a node
    # object when a tree node is clicked (Note: to receive repeated clicks
    # on the same node, define the trait as an Event):
    click = Str

    # Optional extended trait name of the trait that is assigned a node
    # object when a tree node is double-clicked (Note: to receive repeated
    # double-clicks on the same node, define the trait as an Event):
    dclick = Str

    # Optional extended trait name of the trait event that is fired when
    # the application wants to veto a tree action that is in progress (e.g.
    # double-clicking a non-leaf tree node normally opens or closes the
    # node, but a program that handles the double-click event itself may
    # want to veto the open or close operation).  Fire the veto event from
    # the event handler that the operation triggers (e.g. the 'dclick'
    # event handler).
    veto = Str

    # Optional extended trait name of the trait event that is fired when
    # the application wants the currently visible part of the tree widget
    # to repaint itself.
    refresh = Str

    # Mode for the lines that connect tree nodes
    #
    # * 'appearance': Show lines only when they look good.
    # * 'on': Always show lines.
    # * 'off': Don't show lines.
    lines_mode = Enum('appearance', 'on', 'off')
    # FIXME: Document as unimplemented or wx specific.
    # Should row colors alternate?
    alternating_row_colors = Bool(False)

    # Extra vertical padding to add, if any
    vertical_padding = Int(0)

    # Should a double-click expand the node?
    expands_on_dclick = Bool(True)

    # Should labels be wrapped around?  If not, an ellipsis is shown.
    # This works only in the qt backend and only when the tree has a single
    # column.
    word_wrap = Bool(False)
Example #19
0
class GridPlane(Component):
    # Version of this class; needed for persistence.
    __version__ = 0

    # TVTK object that extracts the grid plane.  It is created on the fly
    # according to the type of the input data.
    plane = Instance(tvtk.Object)

    # Axis that is normal to the chosen plane.
    axis = Enum('x',
                'y',
                'z',
                desc='specifies the axis normal to the grid plane')

    # Position of the grid plane.
    position = Range(value=0,
                     low='_low',
                     high='_high',
                     enter_set=True,
                     auto_set=False)

    ########################################
    # Private traits.

    # Lower limit of the position trait; always 0.
    _low = Int(0)

    # Upper limit of the position trait.  It is set dynamically from the
    # input data and the state of the axis trait.  The default is a large
    # value so that no error occurs if the user sets the position before
    # the object has been added to the mayavi tree.
    _high = Int(10000)

    ########################################
    # View related traits.

    # View of this object.
    view = View(
        Group(Item(name='axis'), Item(name='position',
                                      enabled_when='_high > 0')))

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the parent's pure state without the dynamic traits."""
        state = super(GridPlane, self).__get_pure_state__()
        # These traits are created dynamically, so they are not saved.
        for dynamic in ('plane', '_low', '_high'):
            state.pop(dynamic, None)

        return state

    def __set_pure_state__(self, state):
        """Restore `state`, then re-run the position change handler."""
        state_pickler.set_state(self, state)
        self._position_changed(self.position)

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the tvtk pipeline; this component has nothing to create.

        The original note on this hook: it is invoked when the object is
        initialized via `__init__`, at which point the tvtk data pipeline
        is *not* yet set up, so upstream data is not available.  Only
        objects and pipeline parts that do not depend on upstream sources
        and filters should be created here.
        """
        pass

    def update_pipeline(self):
        """Rebuild the geometry filter for the current input dataset.

        The original note on this hook: it is invoked (automatically) when
        the input fires a `pipeline_changed` event.

        Raises `TypeError` (after reporting through `error`) when the input
        dataset is none of the supported VTK types.
        """
        if len(self.inputs) == 0:
            return
        dataset = self.inputs[0].get_output_dataset()

        # (VTK class names, name of the tvtk filter class), tried in order.
        candidates = (
            (('vtkStructuredGrid',), 'StructuredGridGeometryFilter'),
            (('vtkStructuredPoints', 'vtkImageData'),
             'ImageDataGeometryFilter'),
            (('vtkRectilinearGrid',), 'RectilinearGridGeometryFilter'),
        )
        for vtk_names, filter_name in candidates:
            if any(dataset.is_a(vtk_name) for vtk_name in vtk_names):
                plane = getattr(tvtk, filter_name)()
                break
        else:
            msg = "The GridPlane component does not support the %s dataset."\
                  %(dataset.class_name)
            error(msg)
            raise TypeError(msg)

        self.configure_connection(plane, self.inputs[0])
        self.plane = plane
        self.plane.update()
        self.outputs = [plane]
        self._update_limits()
        self._update_extents()
        # For 2D data, default to the axis along which the extent is flat.
        bounds = list(_get_extent(dataset))
        spans = [hi - lo for lo, hi in zip(bounds[::2], bounds[1::2])]
        if spans.count(0) > 0:
            self.axis = ['x', 'y', 'z'][spans.index(0)]

    def update_data(self):
        """Refresh limits and extents, then fire `data_changed`.

        The original note on this hook: it is invoked (automatically) when
        any of the inputs sends a `data_changed` event.
        """
        self._update_limits()
        self._update_extents()
        # Pass the data_changed event on.
        self.data_changed = True

    def has_output_port(self):
        """ The filter has an output port."""
        return True

    def get_output_object(self):
        """ Returns the output port."""
        return self.plane.output_port

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _get_axis_index(self):
        """Return 0, 1 or 2 for an `axis` of 'x', 'y' or 'z'."""
        return {'x': 0, 'y': 1, 'z': 2}[self.axis]

    def _update_extents(self):
        """Collapse the filter's extent to `position` along `axis`."""
        bounds = list(_get_extent(self.plane.input))
        where = self.position
        lower_slot = 2 * self._get_axis_index()
        bounds[lower_slot] = where
        bounds[lower_slot + 1] = where
        try:
            self.plane.set_extent(bounds)
        except AttributeError:
            # Fall back to assigning the `extent` attribute.
            self.plane.extent = bounds

    def _update_limits(self):
        """Set `_high` from the input extent; return the clamped position."""
        bounds = _get_extent(self.plane.input)
        upper = bounds[2 * self._get_axis_index() + 1]
        clamped = min(self.position, upper)
        self._high = upper
        return clamped

    def _axis_changed(self, val):
        """Handler for `axis`: update the limits and the plane."""
        if len(self.inputs) == 0:
            return
        clamped = self._update_limits()
        if self.position != clamped:
            # Changing the position runs the position change handler.
            self.position = clamped
        else:
            self._update_extents()
            self.data_changed = True

    def _position_changed(self, val):
        """Handler for `position`: update extents and fire `data_changed`."""
        if len(self.inputs) == 0:
            return
        self._update_extents()
        self.data_changed = True
Example #20
0
class VelocityView(HasTraits):
    """Traits view plotting horizontal and vertical velocity against time.

    Samples arrive through `vel_ned_callback`, which is registered on
    `link` for ``SBP_MSG_VEL_NED``.  The latest ``NUM_POINTS`` samples are
    kept in rolling numpy buffers, scaled to the selected display units
    and pushed to the plot by `update_plot`.
    """

    python_console_cmds = Dict()
    plot = Instance(Plot)
    velocity_units = Enum(velocity_units_list)
    plot_data = Instance(ArrayPlotData)

    traits_view = View(
        VGroup(
            Spring(height=-2, springy=False),
            HGroup(
                Spring(width=-3, height=8, springy=False),
                Spring(springy=False, width=135),
                Item('velocity_units',
                     label="Display Units"),
                padding=0),
            Item(
                'plot',
                editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8)),
                label='Velocity',
                show_label=False),
        )
    )

    def _velocity_units_changed(self):
        """Update the scale factor and the value-axis title for the new units."""
        if self.velocity_units == 'm/s':
            self.vel_sf = 1.0
        elif self.velocity_units == 'mph':
            self.vel_sf = MPS2MPH
        elif self.velocity_units == 'kph':
            self.vel_sf = MPS2KPH
        # This handler can fire from __init__ before the plot has been
        # built; there is no axis to relabel in that case.
        if self.plot is not None:
            self.plot.value_axis.title = self.velocity_units

    def update_plot(self):
        """Push the buffered samples, scaled by `vel_sf`, to the plot data."""
        self.last_plot_update_time = monotonic()
        self.plot_data.set_data('v_h', self.v_h * self.vel_sf)
        self.plot_data.set_data('v_z', self.v_z * self.vel_sf)
        # The time axis is only replaced once no slot of the time buffer
        # is zero any more.
        if not np.any(self.t == 0):
            self.plot_data.set_data('t', self.t)

    def vel_ned_callback(self, sbp_msg, **metadata):
        """Store one velocity message and schedule a redraw if one is due.

        Messages whose `flags` field is zero are not stored.  The n/e/d
        and tow fields are divided by 1000 before storing (presumably
        mm/s -> m/s and ms -> s; confirm against the SBP definition).
        """
        if sbp_msg.flags != 0:
            # Shift each rolling buffer left by one sample in place.
            memoryview(self.v_h)[:-1] = memoryview(self.v_h)[1:]
            memoryview(self.v_z)[:-1] = memoryview(self.v_z)[1:]
            memoryview(self.t)[:-1] = memoryview(self.t)[1:]
            self.v_h[-1] = np.sqrt(sbp_msg.n * sbp_msg.n + sbp_msg.e * sbp_msg.e) / 1000.0
            self.v_z[-1] = -sbp_msg.d / 1000.0
            self.t[-1] = sbp_msg.tow / 1000.0

        if monotonic() - self.last_plot_update_time < GUI_UPDATE_PERIOD:
            return
        self.update_scheduler.schedule_update('update_plot', self.update_plot)

    def __init__(self, link):
        """Build the plot and subscribe to velocity messages on `link`.

        `link` must provide ``add_callback(callback, msg_type)``.
        """
        super(VelocityView, self).__init__()
        self.velocity_units = 'm/s'
        self.vel_sf = 1.0
        self.v_h = np.zeros(NUM_POINTS)
        self.v_z = np.zeros(NUM_POINTS)
        self.t = np.zeros(NUM_POINTS)

        self.last_plot_update_time = 0

        self.plot_data = ArrayPlotData(
            t=np.arange(NUM_POINTS),
            v_h=[0.0],
            v_z=[0.0]
        )
        self.plot = Plot(
            self.plot_data, auto_colors=colors_list, emphasized=True)
        self.plot.title_color = [0, 0, 0.43]
        self.plot.value_axis.orientation = 'right'
        self.plot.value_axis.axis_line_visible = False
        self.plot.value_axis.title = 'm/s'
        self.plot.value_axis.font = 'modern 8'
        self.plot.index_axis.title = 'GPS Time of Week'
        self.plot.index_axis.title_spacing = 40
        self.plot.index_axis.tick_label_font = 'modern 8'
        self.plot.value_axis.tick_color = 'gray'
        self.plot.index_axis.tick_label_rotate_angle = -45
        self.plot.title_visible = False
        # NOTE(review): this sets an attribute on the view, not on the
        # plot; kept as-is because plot.legend.visible is set just below.
        self.legend_visible = True
        self.plot.legend.visible = True
        self.plot.legend.align = 'll'
        self.plot.legend.line_spacing = 1
        self.plot.legend.font = 'modern 8'
        self.plot.legend.draw_layer = 'overlay'
        self.plot.legend.tools.append(
            LegendTool(self.plot.legend, drag_button="right"))
        self.plot.padding_left = 35
        self.plot.padding_bottom = 60
        # Was misspelled `self.plot_paddint_top`, which never reached the plot.
        self.plot.padding_top = -1
        self.plot.padding_right = 60

        self.plot.plot(
            ('t', 'v_h'), type='line', color='auto', name='Horizontal')
        self.plot.plot(
            ('t', 'v_z'), type='line', color='auto', name='Vertical')

        self.python_console_cmds = {'vel': self}

        # The scheduler must exist before the callback is registered,
        # because vel_ned_callback uses it as soon as a message arrives.
        self.update_scheduler = UpdateScheduler()

        self.link = link
        self.link.add_callback(self.vel_ned_callback, SBP_MSG_VEL_NED)
Example #21
0
class MATS3DScalarDamage(MATS3DEval):
    '''
    Scalar Damage Model.

    The elastic stiffness ``D_el`` is scaled by ``(1 - omega)``, where the
    damage parameter ``omega`` is a function of the largest equivalent
    strain stored in the material state array.
    '''

    implements(IMATSEval)

    #---------------------------------------------------------------------------
    # Material parameters
    #---------------------------------------------------------------------------

    E = Float(
        1.,  #34e+3,
        label="E",
        desc="Young's Modulus",
        auto_set=False)
    nu = Float(0.2, label='nu', desc="Poison's ratio", auto_set=False)
    epsilon_0 = Float(59e-6,
                      label="eps_0",
                      desc="Breaking Strain",
                      auto_set=False)

    epsilon_f = Float(191e-6,
                      label="eps_f",
                      desc="Shape Factor",
                      auto_set=False)

    #---------------------------------------------------------------------------
    # Parameters of the numerical algorithm (integration)
    #---------------------------------------------------------------------------

    # Single declaration: an earlier duplicate (with "algorithmic"
    # misspelled) was overridden by this one and has been removed.
    stiffness = Enum("secant", "algorithmic")

    strain_norm = EitherType(
        klasses=[Energy, Euclidean, Mises, Rankine, Mazars])

    D_el = Property(Array(float), depends_on='E, nu')

    @cached_property
    def _get_D_el(self):
        '''
        Return elastic stiffness matrix (6x6, isotropic).

        The value is cached and recomputed when E or nu change.
        '''
        # Previously a self-calling cached stub was shadowed by a second,
        # uncached definition of this getter; this is now the only one.
        D_el = zeros((6, 6))
        t2 = 1. / (1. + self.nu)
        two_mu = self.E * t2
        lame = self.E * self.nu * t2 / (1. - 2. * self.nu)
        normal = two_mu + lame
        shear = two_mu / 2.
        # Normal block: lame everywhere, lame + 2 mu on its diagonal.
        D_el[:3, :3] = lame
        for i in range(3):
            D_el[i][i] = normal
            # Shear block is diagonal.
            D_el[i + 3][i + 3] = shear
        return D_el

    # This event can be used by the clients to trigger an action upon
    # the completed reconfiguration of the material model
    #
    changed = Event

    #--------------------------------------------------------------------------
    # View specification
    #--------------------------------------------------------------------------

    view_traits = View(VSplit(
        Group(Item('E'), Item('nu'), Item('strain_norm')),
        Group(
            Item('stiffness', style='custom'),
            Spring(resizable=True),
            label='Configuration parameters',
            show_border=True,
        ),
    ),
                       resizable=True)

    #--------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #--------------------------------------------------------------------------

    def get_state_array_size(self):
        '''
        Return the number of values stored in the state array
        (e_max and omega).
        '''
        return 2

    def setup(self, sctx):
        '''
        Initialize state variables to zero.
        @param sctx:spatial context
        '''
        sctx.mats_state_array[:] = zeros(2, float_)

    def new_cntl_var(self):
        '''
        Return a zeroed control variable array (6 components)
        '''
        return zeros(6, float_)

    def new_resp_var(self):
        '''
        Return a zeroed response variable array (6 components)
        '''
        return zeros(6, float_)

    #--------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #--------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng, d_eps, tn, tn1, eps_avg=None):
        '''
        Corrector predictor computation.
        @param sctx spatial context holding mats_state_array
        @param eps_app_eng input variable - engineering strain
        @param d_eps strain increment
        @param eps_avg optional averaged strain; defaults to eps_app_eng
        @return tuple (sigma, D_e_dam) - stress and material stiffness
        '''
        # Identity test: `!= None` on a numpy array compares element-wise
        # and cannot be used as a condition.
        if eps_avg is None:
            eps_avg = eps_app_eng

        if sctx.update_state_on:
            # Store the state reached at the previous strain.
            eps_n = eps_avg - d_eps

            e_max, omega = self._get_state_variables(sctx, eps_n)

            sctx.mats_state_array[0] = e_max
            sctx.mats_state_array[1] = omega

        e_max, omega = self._get_state_variables(sctx, eps_app_eng)

        if self.stiffness == "algorithmic" and e_max > self.epsilon_0:
            D_e_dam = self._get_alg_stiffness(eps_app_eng, e_max, omega)
        else:
            D_e_dam = (1 - omega) * self.D_el

        sigma = dot(((1 - omega) * self.D_el), eps_app_eng)

        return sigma, D_e_dam

    #--------------------------------------------------------------------------
    # Subsidiary methods realizing configurable features
    #--------------------------------------------------------------------------
    def _get_state_variables(self, sctx, eps_app_eng):
        '''
        Return (e_max, omega) for the given strain without modifying
        the stored state.
        '''
        e_max = sctx.mats_state_array[0]
        omega = sctx.mats_state_array[1]

        f_trial = self.strain_norm.get_f_trial(eps_app_eng, self.D_el, self.E,
                                               self.nu, e_max)
        if f_trial > 0:
            e_max += f_trial
            omega = self._get_omega(e_max)

        return e_max, omega

    def _get_omega(self, kappa):
        '''
        Return new value of damage parameter
        @param kappa: largest equivalent strain reached so far
        '''
        epsilon_0 = self.epsilon_0
        epsilon_f = self.epsilon_f
        if kappa >= epsilon_0:
            return 1. - epsilon_0 / kappa * exp(
                -1 * (kappa - epsilon_0) / epsilon_f)
        else:
            return 0.

    def _get_alg_stiffness(self, eps_app_eng, e_max, omega):
        '''
        Return algorithmic stiffness matrix
        @param eps_app_eng:strain
        @param e_max:kappa
        @param omega:damage parameter
        '''
        epsilon_0 = self.epsilon_0
        epsilon_f = self.epsilon_f
        decay = exp(-(e_max - epsilon_0) / epsilon_f)
        # Derivative of omega with respect to kappa.
        dodk = epsilon_0 / (e_max * e_max) * decay + \
                epsilon_0 / e_max / epsilon_f * decay
        dede = self.strain_norm.get_dede(eps_app_eng, self.D_el, self.E,
                                         self.nu)
        D_alg = (1 - omega) * self.D_el - \
                dot(dot(self.D_el, eps_app_eng), dede) * dodk
        return D_alg

    #--------------------------------------------------------------------------
    # Response trace evaluators
    #--------------------------------------------------------------------------

    def get_omega(self, sctx, eps_app_eng):
        '''
        Return damage parameter for RT
        @param sctx:spatial context
        @param eps_app_eng:actual strain
        '''
        return array([sctx.mats_state_array[1]])

    # Declare and fill-in the rte_dict - it is used by the clients to
    # assemble all the available time-steppers.
    #
    rte_dict = Trait(Dict)

    def _rte_dict_default(self):
        # get_sig_app is not defined in this class; presumably inherited
        # from MATS3DEval.
        return {'sig_app': self.get_sig_app, 'omega': self.get_omega}
Example #22
0
class TimingStrategy(Strategy):
    """Strategy that trades one target account from timing signals.

    Subclasses override `genSignal`.  Each signal is optionally delayed by
    `SignalDelay` periods, turned into a trade target according to
    `TradeTarget`, and executed through orders on `TargetAccount`.
    Capital per ID is tracked in `_ValueAllocated` / `_CashAllocated`.
    """

    # Signal delay (in periods)
    SignalDelay = Int(0, label="信号滞后期", arg_type="Integer", order=1)
    # Signal validity (in periods)
    SigalValidity = Int(1, label="信号有效期", arg_type="Integer", order=2)
    # Datetimes at which signals are generated; empty means every period
    SigalDTs = List(label="信号触发时点", arg_type="DateTimeList", order=3)
    # Target account
    TargetAccount = Instance(Account,
                             label="目标账户",
                             arg_type="ArgObject",
                             order=4)
    # Capital allocation per ID
    ValueAllocated = Instance(pd.Series,
                              arg_type="Series",
                              label="资金分配",
                              order=5)
    # Trade target mode: lock trade amount / lock target position /
    # lock target amount
    TradeTarget = Enum("锁定买卖金额",
                       "锁定目标仓位",
                       "锁定目标金额",
                       label="交易目标",
                       arg_type="SingleOption",
                       order=6)

    def __init__(self,
                 name,
                 factor_table=None,
                 sys_args=None,
                 config_file=None,
                 **kwargs):
        """Create the strategy.

        `sys_args` defaults to a fresh empty dict for every instance
        (a shared mutable default was used before).
        """
        self._FT = factor_table  # factor table
        # All allocation resets issued so far, {datetime: allocation}
        self._AllAllocationReset = {}
        self._ValueAllocated = None
        self._CashAllocated = None
        super().__init__(name=name,
                         accounts=[],
                         fts=([] if self._FT is None else [self._FT]),
                         sys_args=({} if sys_args is None else sys_args),
                         config_file=config_file,
                         **kwargs)

    @on_trait_change("TargetAccount")
    def on_TargetAccount_changed(self, obj, name, old, new):
        """Keep `Accounts` in sync with the selected target account."""
        if (self.TargetAccount is not None) and (self.TargetAccount
                                                 not in self.Accounts):
            self.Accounts.append(self.TargetAccount)
        elif (self.TargetAccount is None) and (old in self.Accounts):
            self.Accounts.remove(old)

    @on_trait_change("ValueAllocated")
    def on_ValueAllocated_changed(self, obj, name, old, new):
        """Flag that the allocation must be reset on the next move."""
        self._isAllocationReseted = True

    @property
    def MainFactorTable(self):
        return self._FT

    @property
    def TargetIDs(self):
        """List of IDs that currently have capital allocated."""
        if self._ValueAllocated is not None:
            return self._ValueAllocated[
                self._ValueAllocated != 0].index.tolist()
        if self.ValueAllocated is not None:
            return self.ValueAllocated.index.tolist()
        if self.TargetAccount is not None:
            return self.TargetAccount.IDs
        return []

    @property
    def PositionLevel(self):
        """Position level of every ID in the target account.

        Raises __QS_Error__ when no target account has been set.
        """
        if self.TargetAccount is None: raise __QS_Error__("尚未设置目标账户!")
        if self._CashAllocated is None:
            return pd.Series(0.0, index=self.TargetAccount.IDs)
        PositionAmount = self.TargetAccount.PositionAmount
        PositionValue = PositionAmount + self._CashAllocated
        PositionLevel = PositionAmount / PositionValue
        PositionLevel[PositionValue == 0] = 0.0
        # A non-zero position with zero total value counts as fully
        # long or short, depending on its sign.
        Mask = ((PositionAmount != 0) & (PositionValue == 0))
        PositionLevel[Mask] = np.sign(PositionAmount)[Mask]
        return PositionLevel

    def _resetAllocation(self, new_allocation):
        """Return the allocation Series, indexed by the account's IDs.

        None spreads the account value equally; IDs missing from
        `new_allocation` receive 0.
        """
        IDs = self.TargetAccount.IDs
        if new_allocation is None:
            return pd.Series(self.TargetAccount.AccountValue / len(IDs),
                             index=IDs)
        elif new_allocation.index.intersection(IDs).shape[0] == 0:
            return pd.Series(0.0, index=IDs)
        else:
            # reindex (not .loc) so IDs absent from the allocation become
            # NaN instead of raising KeyError.
            return new_allocation.reindex(IDs).fillna(0.0)

    def __QS_start__(self, mdl, dts, **kwargs):
        """Initialise run-time state; returns () if already started."""
        if self._isStarted: return ()
        Rslt = super().__QS_start__(mdl=mdl, dts=dts, **kwargs)
        self._TradeTarget = None  # locked trade target
        self._SignalExcutePeriod = 0  # periods the signal has been executed
        self._ValueAllocated = self._resetAllocation(self.ValueAllocated)
        self._CashAllocated = self._ValueAllocated - self.TargetAccount.PositionAmount.fillna(
            0.0)
        self._AllAllocationReset = {
            dts[0] - dt.timedelta(1): self._ValueAllocated
        }
        self._isAllocationReseted = False
        # Control variables for delayed signal release
        self._TempData = {}
        self._TempData['StoredSignal'] = []  # signals waiting to be released
        self._TempData['LagNum'] = []  # periods elapsed since each signal
        self._TempData['LastSignal'] = None  # last generated signal
        self._TempData['StoredAllocation'] = []  # allocations waiting to be released
        self._TempData['AllocationLagNum'] = []  # periods elapsed since each allocation
        self._isStarted = True
        return (self._FT, ) + Rslt

    def __QS_move__(self, idt, **kwargs):
        """Advance to `idt`: move accounts, process signals, then trade."""
        if self._iDT == idt: return 0
        self._iDT = idt
        TradingRecord = {
            iAccount.Name: iAccount.__QS_move__(idt, **kwargs)
            for iAccount in self.Accounts
        }
        if (not self.SigalDTs) or (idt in self.SigalDTs):
            Signal = self.genSignal(idt, TradingRecord)
            if Signal is not None: self._AllSignals[idt] = Signal
        else: Signal = None
        Signal = self._bufferSignal(Signal)
        NewAllocation = None
        if self._isAllocationReseted:
            NewAllocation = self._resetAllocation(self.ValueAllocated)
            self._AllAllocationReset[idt] = NewAllocation
            self._isAllocationReseted = False
        NewAllocation = self._bufferAllocationReset(NewAllocation)
        if NewAllocation is not None:
            self._ValueAllocated = NewAllocation
            self._CashAllocated = self._ValueAllocated - self.TargetAccount.PositionAmount.fillna(
                0.0)
        else:  # update allocated cash from the trades just executed
            iTradingRecord = TradingRecord[self.TargetAccount.Name]
            if iTradingRecord.shape[0] > 0:
                CashChanged = pd.Series(
                    (iTradingRecord["买卖数量"] * iTradingRecord["价格"] +
                     iTradingRecord["交易费"]).values,
                    index=iTradingRecord["ID"].values)
                # IDs without trades become NaN after reindex and are
                # treated as zero cash change.
                CashChanged = CashChanged.groupby(level=0).sum().reindex(
                    self._CashAllocated.index)
                self._CashAllocated -= CashChanged.fillna(0.0)
        self.trade(idt, TradingRecord, Signal)
        for iAccount in self.Accounts:
            iAccount.__QS_after_move__(idt, **kwargs)
        return 0

    def _output(self):
        Output = super()._output()
        Output["Strategy"]["择时信号"] = pd.DataFrame(self._AllSignals).T
        Output["Strategy"]["资金分配"] = pd.DataFrame(self._AllAllocationReset).T
        return Output

    def genSignal(self, idt, trading_record):
        """Hook for subclasses; returns None (no signal) here."""
        return None

    def trade(self, idt, trading_record, signal):
        """Turn `signal` (or the still-valid trade target) into orders."""
        PositionAmount = self.TargetAccount.PositionAmount
        PositionValue = PositionAmount + self._CashAllocated
        if signal is not None:  # new signal: form a new trade target
            if signal.shape[0] > 0:
                # IDs missing from the signal become NaN and yield no order.
                signal = signal.reindex(PositionValue.index)
            else:
                signal = pd.Series(np.nan, index=PositionValue.index)
            signal[self._ValueAllocated == 0] = 0.0
            if self.TradeTarget == "锁定买卖金额":
                self._TradeTarget = signal * PositionValue.abs(
                ) - PositionAmount
            elif self.TradeTarget == "锁定目标金额":
                self._TradeTarget = PositionValue.abs() * signal
            elif self.TradeTarget == "锁定目标仓位":
                self._TradeTarget = signal
            self._SignalExcutePeriod = 0
        elif self._TradeTarget is not None:
            # No new signal: adjust the trade target by what was traded.
            self._SignalExcutePeriod += 1
            if self._SignalExcutePeriod >= self.SigalValidity:
                self._TradeTarget = None
                self._SignalExcutePeriod = 0
            else:
                iTradingRecord = trading_record[self.TargetAccount.Name]
                if iTradingRecord.shape[0] > 0:
                    if self.TradeTarget == "锁定买卖金额":
                        TargetChanged = pd.Series(
                            (iTradingRecord["买卖数量"] *
                             iTradingRecord["价格"]).values,
                            index=iTradingRecord["ID"].values)
                        TargetChanged = TargetChanged.groupby(
                            level=0).sum().reindex(
                                self._TradeTarget.index).fillna(0.0)
                        TradeTarget = self._TradeTarget - TargetChanged
                        # A target that flipped sign has been overshot.
                        TradeTarget[np.sign(self._TradeTarget) *
                                    np.sign(TradeTarget) < 0] = 0.0
                        self._TradeTarget = TradeTarget
        # Place orders according to the trade target
        if self._TradeTarget is not None:
            if self.TradeTarget == "锁定买卖金额":
                Orders = self._TradeTarget
            elif self.TradeTarget == "锁定目标仓位":
                Orders = self._TradeTarget * PositionValue.abs(
                ) - PositionAmount
            elif self.TradeTarget == "锁定目标金额":
                Orders = self._TradeTarget - PositionAmount
            Orders = Orders / self.TargetAccount.LastPrice
            Orders = Orders[pd.notnull(Orders) & (Orders != 0)]
            if Orders.shape[0] == 0: return 0
            Orders = pd.DataFrame(Orders.values,
                                  index=Orders.index,
                                  columns=["数量"])
            Orders["目标价"] = np.nan
            self.TargetAccount.order(combined_order=Orders)
        return 0

    def _popDelayed(self, item, store_key, lag_key):
        """Queue `item` and return the item whose delay has expired.

        `store_key` / `lag_key` name the parallel lists in `_TempData`
        holding queued items and their ages.  If several items expire at
        once, the most recently queued of them is returned.  With
        `SignalDelay` <= 0 the item passes straight through.
        """
        if self.SignalDelay <= 0: return item
        Stored = self._TempData[store_key]
        Lags = self._TempData[lag_key]
        if item is not None:
            Stored.append(item)
            Lags.append(-1)
        Lags[:] = [iLagNum + 1 for iLagNum in Lags]
        Released = None
        while Stored and Lags[0] >= self.SignalDelay:
            Released = Stored.pop(0)
            Lags.pop(0)
        return Released

    def _bufferSignal(self, signal):
        """Buffer the signal and pop the one whose delay has expired."""
        return self._popDelayed(signal, 'StoredSignal', 'LagNum')

    def _bufferAllocationReset(self, allocation):
        """Buffer the allocation reset and pop the one whose delay expired."""
        return self._popDelayed(allocation, 'StoredAllocation',
                                'AllocationLagNum')
Example #23
0
class JitterPlot(AbstractPlotRenderer):
    """A renderer for a jitter plot, a 1D plot with some width in the
    dimension perpendicular to the primary axis.  Useful for understanding
    dense collections of points.
    """

    # The data source of values
    index = Instance(ArrayDataSource)

    # The single mapper that this plot uses
    mapper = Instance(AbstractMapper)

    # Just an alias for "mapper"
    index_mapper = Property(lambda obj, attr: getattr(obj, "mapper"),
                            lambda obj, attr, val: setattr(obj, "mapper", val))

    x_mapper = Property()
    y_mapper = Property()

    orientation = Enum("h", "v")

    # The size, in pixels, of the area over which to spread the data points
    # along the dimension orthogonal to the index direction.
    jitter_width = Int(50)

    # How the plot should center itself along the orthogonal dimension if the
    # component's width is greater than the jitter_width
    #align = Enum("center", "left", "right", "top", "bottom")

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys.
    marker = MarkerTrait

    # The pixel size of the marker, not including the thickness of the outline.
    marker_size = Float(4.0)

    # The CompiledPath to use if **marker** is set to "custom". This attribute
    # must be a compiled path for the Kiva context onto which this plot will
    # be rendered.  Usually, importing kiva.GraphicsContext will do
    # the right thing.
    custom_symbol = Any

    # The function which actually renders the markers
    render_markers_func = Callable(render_markers)

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline is drawn.
    line_width = Float(1.0)

    # The fill color of the marker.
    color = black_color_trait

    # The color of the outline to draw around the marker.
    outline_color = black_color_trait

    # Override the base class default for **origin**, which specifies corners.
    # Since this is a 1D plot, it only makes sense to have the origin at the
    # edges.
    origin = Enum("bottom", "top", "left", "right")

    #------------------------------------------------------------------------
    # Built-in selection handling
    #------------------------------------------------------------------------

    # The name of the metadata attribute to look for on the datasource for
    # determine which points are selected and which are not.  The metadata
    # value returned should be a *list* of numpy arrays suitable for masking
    # the values returned by index.get_data().
    selection_metadata_name = Str("selections")

    # The color to use to render selected points
    selected_color = black_color_trait

    # Alpha value to apply to points that are not in the set of "selected"
    # points
    unselected_alpha = Float(0.3)
    unselected_line_width = Float(0.0)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    _cache_valid = Bool(False)

    _cached_data_pts = Any()
    _cached_data_pts_sorted = Any()
    _cached_data_argsort = Any()

    _screen_cache_valid = Bool(False)
    _cached_screen_pts = Any()
    _cached_screen_map = Any()  # dict mapping index to value points

    # The random number seed used to generate the jitter.  We store this
    # so that the jittering is stable as the data is replotted.
    _jitter_seed = Trait(None, None, Int)

    #------------------------------------------------------------------------
    # Component/AbstractPlotRenderer interface
    #------------------------------------------------------------------------

    def map_screen(self, data_array):
        """ Maps an array of data points into screen space and returns it as
        an array.  Although the orthogonal (non-scaled) axis does not have
        a mapper, this method returns the scattered values in that dimension.

        Returns an (N, 2) array of screen points, or an empty array when
        *data_array* is empty.

        Implements the AbstractPlotRenderer interface.
        """
        if len(data_array) == 0:
            return np.zeros(0)

        if self._screen_cache_valid:
            sm = self._cached_screen_map
            new_x = [x for x in data_array if x not in sm]
            if new_x:
                new_y = self._make_jitter_vals(len(new_x))
                sm.update(zip(new_x, new_y))
            xs = self.mapper.map_screen(data_array)
            # The cache is keyed by data value, so it must be looked up
            # with the data values, not the mapped screen coordinates.
            ys = [sm[x] for x in data_array]

        else:
            if self._jitter_seed is None:
                self._set_seed(data_array)
            xs = self.mapper.map_screen(data_array)
            ys = self._make_jitter_vals(len(data_array))

        if self.orientation == "h":
            return np.vstack((xs, ys)).T
        else:
            return np.vstack((ys, xs)).T

    def _make_jitter_vals(self, numpts):
        """ Returns *numpts* random offsets spread over **jitter_width**
        pixels, centered on the component along the orthogonal dimension.
        """
        vals = np.random.uniform(0, self.jitter_width, numpts)
        if self.orientation == "h":
            ymin = self.y
            height = self.height
            vals += ymin + height / 2 - self.jitter_width / 2
        else:
            xmin = self.x
            width = self.width
            vals += xmin + width / 2 - self.jitter_width / 2
        return vals

    def map_data(self, screen_pt):
        """ Maps a screen space point into the index space of the plot.
        """
        x, y = screen_pt
        if self.orientation == "v":
            x, y = y, x
        return self.mapper.map_data(x)

    def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True,
                  index_only=True):
        """ Maps a screen space point to an index into the plot's index array(s).

        Returns None when there are no cached screen points, when the
        point lies outside the mapper's range (and *outside_returns_none*
        is True), or when the nearest point is farther than *threshold*
        pixels away.  A *threshold* of 0.0 disables the distance check.
        *index_only* is accepted but not used here.
        """
        screen_points = self._cached_screen_pts

        # The cache may not have been populated yet.
        if screen_points is None or len(screen_points) == 0:
            return None

        data_pt = self.map_data(screen_pt)
        if ((data_pt < self.mapper.range.low) or
                (data_pt > self.mapper.range.high)) and outside_returns_none:
            return None

        if self._cached_data_pts_sorted is None:
            self._cached_data_argsort = np.argsort(self._cached_data_pts)
            self._cached_data_pts_sorted = self._cached_data_pts[
                self._cached_data_argsort]

        data = self._cached_data_pts_sorted
        try:
            ndx = reverse_map_1d(data, data_pt, "ascending")
        except IndexError:
            if outside_returns_none:
                return None
            else:
                # Clamp to the nearest end of the sorted data.
                if data_pt < data[0]:
                    return 0
                else:
                    return len(data) - 1

        orig_ndx = self._cached_data_argsort[ndx]

        if threshold == 0.0:
            return orig_ndx

        sx, sy = screen_points[orig_ndx]
        if sqrt((screen_pt[0] - sx)**2 + (screen_pt[1] - sy)**2) <= threshold:
            return orig_ndx
        else:
            return None
Example #24
0
class FileDialog(MFileDialog, Dialog):
    """ File dialog implemented on top of ``wx.FileDialog``.  The API is
    documented on the IFileDialog interface.
    """

    # 'IFileDialog' interface ---------------------------------------------#

    action = Enum("open", "open files", "save as")

    default_directory = Str()

    default_filename = Str()

    default_path = Str()

    directory = Str()

    filename = Str()

    path = Str()

    paths = List(Str)

    wildcard = Str()

    wildcard_index = Int(0)

    # ------------------------------------------------------------------------
    # Protected 'IDialog' interface.
    # ------------------------------------------------------------------------

    def _create_contents(self, parent):
        # Nothing to build: wx supplies a ready-made dialog.
        pass

    # ------------------------------------------------------------------------
    # 'IWindow' interface.
    # ------------------------------------------------------------------------

    def close(self):
        """ Copy the user's selection into the traits, then close. """
        control = self.control

        # Path of whatever the user picked.
        self.path = str(control.GetPath())
        # Work around wx bug throwing exception on cancel of file dialog:
        # only ask for the path list when a path was actually returned.
        self.paths = control.GetPaths() if len(self.path) > 0 else []

        # Split the path into directory and filename parts.
        self.directory, self.filename = os.path.split(self.path)

        # Remember which filter was selected.
        self.wildcard_index = control.GetFilterIndex()
        # Carry on with the normal close.
        super(FileDialog, self).close()

    # ------------------------------------------------------------------------
    # Protected 'IWidget' interface.
    # ------------------------------------------------------------------------

    def _create_control(self, parent):
        """ Build and return the underlying ``wx.FileDialog``. """
        # A default path is only honoured when neither a default
        # directory nor a default filename was supplied; it is then split
        # into those two components.
        only_path_given = (len(self.default_path) != 0
                           and len(self.default_directory) == 0
                           and len(self.default_filename) == 0)
        if only_path_given:
            start_dir, start_file = os.path.split(self.default_path)
        else:
            start_dir = self.default_directory
            start_file = self.default_filename

        # Anything other than the two "open" actions is a save dialog.
        open_styles = {
            "open": wx.FD_OPEN,
            "open files": wx.FD_OPEN | wx.FD_MULTIPLE,
        }
        style = open_styles.get(self.action,
                                wx.FD_SAVE | wx.FD_OVERWRITE_PROMPT)

        # Create the actual dialog.
        dialog = wx.FileDialog(
            parent,
            self.title,
            defaultDir=start_dir,
            defaultFile=start_file,
            style=style,
            wildcard=self.wildcard.rstrip("|"),
        )

        dialog.SetFilterIndex(self.wildcard_index)

        return dialog

    # ------------------------------------------------------------------------
    # Trait handlers.
    # ------------------------------------------------------------------------

    def _wildcard_default(self):
        """ Default for the `wildcard` trait: WILDCARD_ALL. """

        return self.WILDCARD_ALL
Example #25
0

from operator import itemgetter

from traits.api import BaseTraitHandler, CTrait, Enum, TraitError

from .ui_traits import SequenceTypes



# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# Layout orientation for a control and its associated editor; one of
# "horizontal" or "vertical".
Orientation = Enum("horizontal", "vertical")

# Docking drag bar style; one of "horizontal", "vertical", "tab" or "fixed".
DockStyle = Enum("horizontal", "vertical", "tab", "fixed")


def user_name_for(name):
    """ Returns a "user-friendly" name for a specified trait.
    """
    # Treat underscores as word separators and capitalize the first letter.
    name = name.replace("_", " ")
    name = name[:1].upper() + name[1:]
    result = ""
    last_lower = 0
    for c in name:
        # Insert a space before an upper-case letter when `last_lower` is set.
        # NOTE(review): the visible body stops here -- nothing updates
        # `last_lower`, appends `c`, or returns `result`; the remainder of the
        # function appears to be missing from this file.
        if c.isupper() and last_lower:
            result += " "
Example #26
0
class SwiftConsole(HasTraits):
    """Traits-defined Swift Console.

    link : object
      Serial driver
    update : bool
      Update the firmware
    log_level_filter : str
      Syslog string, one of "ERROR", "WARNING", "INFO", "DEBUG".
    skip_settings : bool
      Don't read the device settings. Set to False when the console is reading
      from a network connection only.

    """

    link = Instance(sbpc.Handler)
    console_output = Instance(OutputList())
    python_console_env = Dict
    device_serial = Str('')
    dev_id = Str('')
    tracking_view = Instance(TrackingView)
    solution_view = Instance(SolutionView)
    baseline_view = Instance(BaselineView)
    observation_view = Instance(ObservationView)
    networking_view = Instance(SbpRelayView)
    observation_view_base = Instance(ObservationView)
    system_monitor_view = Instance(SystemMonitorView)
    settings_view = Instance(SettingsView)
    update_view = Instance(UpdateView)
    imu_view = Instance(IMUView)
    mag_view = Instance(MagView)
    spectrum_analyzer_view = Instance(SpectrumAnalyzerView)
    skylark_view = Instance(SkylarkView)
    log_level_filter = Enum(list(SYSLOG_LEVELS.itervalues()))
    """"
  mode : baseline and solution view - SPP, Fixed or Float
  num_sat : baseline and solution view - number of satellites
  port : which port is Swift Device is connected to
  directory_name : location of logged files
  json_logging : enable JSON logging
  csv_logging : enable CSV logging

  """

    mode = Str('')
    num_sats = Int(0)
    cnx_desc = Str('')
    latency = Str('')
    uuid = Str('')
    directory_name = Directory
    json_logging = Bool(True)
    csv_logging = Bool(False)
    cnx_icon = Str('')
    heartbeat_count = Int()
    last_timer_heartbeat = Int()
    solid_connection = Bool(False)

    csv_logging_button = SVGButton(
        toggle=True,
        label='CSV log',
        tooltip='start CSV logging',
        toggle_tooltip='stop CSV logging',
        filename=resource_filename('console/images/iconic/pause.svg'),
        toggle_filename=resource_filename('console/images/iconic/play.svg'),
        orientation='vertical',
        width=2,
        height=2,
    )
    json_logging_button = SVGButton(
        toggle=True,
        label='JSON log',
        tooltip='start JSON logging',
        toggle_tooltip='stop JSON logging',
        filename=resource_filename('console/images/iconic/pause.svg'),
        toggle_filename=resource_filename('console/images/iconic/play.svg'),
        orientation='vertical',
        width=2,
        height=2,
    )
    paused_button = SVGButton(
        label='',
        tooltip='Pause console update',
        toggle_tooltip='Resume console update',
        toggle=True,
        filename=resource_filename('console/images/iconic/pause.svg'),
        toggle_filename=resource_filename('console/images/iconic/play.svg'),
        width=8,
        height=8)
    clear_button = SVGButton(
        label='',
        tooltip='Clear console buffer',
        filename=resource_filename('console/images/iconic/x.svg'),
        width=8,
        height=8)

    view = View(VSplit(
        Tabbed(Item('tracking_view', style='custom', label='Tracking'),
               Item('solution_view', style='custom', label='Solution'),
               Item('baseline_view', style='custom', label='Baseline'),
               VSplit(
                   Item('observation_view', style='custom', show_label=False),
                   Item('observation_view_base',
                        style='custom',
                        show_label=False),
                   label='Observations',
               ),
               Item('settings_view', style='custom', label='Settings'),
               Item('update_view', style='custom', label='Update'),
               Tabbed(Item('system_monitor_view',
                           style='custom',
                           label='System Monitor'),
                      Item('imu_view', style='custom', label='IMU'),
                      Item('mag_view', style='custom', label='Magnetometer'),
                      Item('networking_view',
                           label='Networking',
                           style='custom',
                           show_label=False),
                      Item('spectrum_analyzer_view',
                           label='Spectrum Analyzer',
                           style='custom'),
                      label='Advanced',
                      show_labels=False),
               Item('skylark_view', style='custom', label='Skylark'),
               show_labels=False),
        VGroup(
            VGroup(
                HGroup(
                    Spring(width=4, springy=False),
                    Item('paused_button',
                         show_label=False,
                         padding=0,
                         width=8,
                         height=8),
                    Item('clear_button', show_label=False, width=8, height=8),
                    Item('', label='Console Log', emphasized=True),
                    Item('csv_logging_button',
                         emphasized=True,
                         show_label=False,
                         width=12,
                         height=-30,
                         padding=0),
                    Item('json_logging_button',
                         emphasized=True,
                         show_label=False,
                         width=12,
                         height=-30,
                         padding=0),
                    Item(
                        'directory_name',
                        show_label=False,
                        springy=True,
                        tooltip=
                        'Choose location for file logs. Default is home/SwiftNav.',
                        height=-25,
                        enabled_when='not(json_logging or csv_logging)',
                        editor_args={'auto_set': True}),
                    UItem(
                        'log_level_filter',
                        style='simple',
                        padding=0,
                        height=8,
                        show_label=True,
                        tooltip=
                        'Show log levels up to and including the selected level of severity.\nThe CONSOLE log level is always visible.'
                    ),
                ),
                Item('console_output',
                     style='custom',
                     editor=InstanceEditor(),
                     height=125,
                     show_label=False,
                     full_size=True),
            ),
            HGroup(
                Spring(width=4, springy=False),
                Item('',
                     label='Interface:',
                     emphasized=True,
                     tooltip='Interface for communicating with Swift device'),
                Item('cnx_desc', show_label=False, style='readonly'),
                Item('',
                     label='FIX TYPE:',
                     emphasized=True,
                     tooltip='Device Mode: SPS, Float RTK, Fixed RTK'),
                Item('mode', show_label=False, style='readonly'),
                Item('',
                     label='#Sats:',
                     emphasized=True,
                     tooltip='Number of satellites used in solution'),
                Item('num_sats', padding=2, show_label=False,
                     style='readonly'),
                Item('',
                     label='Base Latency:',
                     emphasized=True,
                     tooltip='Corrections latency (-1 means no corrections)'),
                Item('latency', padding=2, show_label=False, style='readonly'),
                Item('',
                     label='Device UUID:',
                     emphasized=True,
                     tooltip='Universally Unique Device Identifier (UUID)'),
                Item('uuid',
                     padding=2,
                     show_label=False,
                     style='readonly',
                     width=6),
                Spring(springy=True),
                Item('cnx_icon',
                     show_label=False,
                     padding=0,
                     width=8,
                     height=8,
                     visible_when='solid_connection',
                     springy=False,
                     editor=ImageEditor(
                         allow_clipping=False,
                         image=ImageResource(
                             resource_filename(
                                 'console/images/iconic/arrows_blue.png')))),
                Item('cnx_icon',
                     show_label=False,
                     padding=0,
                     width=8,
                     height=8,
                     visible_when='not solid_connection',
                     springy=False,
                     editor=ImageEditor(
                         allow_clipping=False,
                         image=ImageResource(
                             resource_filename(
                                 'console/images/iconic/arrows_grey.png')))),
                Spring(width=4, height=-2, springy=False),
            ),
            Spring(height=1, springy=False),
        ),
    ),
                icon=icon,
                resizable=True,
                width=800,
                height=600,
                handler=ConsoleHandler(),
                title=CONSOLE_TITLE)

    def print_message_callback(self, sbp_msg, **metadata):
        """Write a print message's payload to the console output.

        The payload is encoded to ASCII (characters that cannot be encoded are
        dropped) and split on newlines. Lines are written last-to-first, each
        at the log level derived from the text before its first ':'.
        """
        try:
            text = sbp_msg.payload.encode('ascii', 'ignore')
            lines = text.split('\n')
            lines.reverse()
            for line in lines:
                prefix = line.split(':')[0]
                self.console_output.write_level(line,
                                                str_to_log_level(prefix))
        except UnicodeDecodeError:
            print("Critical Error encoding the serial stream as ascii.")

    def log_message_callback(self, sbp_msg, **metadata):
        """Write a log message's text to the console output.

        The text is encoded to ASCII (characters that cannot be encoded are
        dropped) and split on newlines. Lines are written last-to-first, all
        at the message's own ``level``.
        """
        try:
            text = sbp_msg.text.encode('ascii', 'ignore')
            lines = text.split('\n')
            lines.reverse()
            for line in lines:
                self.console_output.write_level(line, sbp_msg.level)
        except UnicodeDecodeError:
            print("Critical Error encoding the serial stream as ascii.")

    def ext_event_callback(self, sbp_msg, **metadata):
        """Print a one-line description of an external event message."""
        event = MsgExtEvent(sbp_msg)
        # Bit 0 of the flags selects the reported edge, bit 1 the reported
        # time quality.
        edge = "Rising" if event.flags & (1 << 0) else "Falling"
        quality = "good" if event.flags & (1 << 1) else "unknown"
        print(
            'External event: %s edge on pin %d at wn=%d, tow=%d, time qual=%s'
            % (edge, event.pin, event.wn, event.tow, quality))

    def cmd_resp_callback(self, sbp_msg, **metadata):
        """Print the response code carried by a command response message."""
        response = MsgCommandResp(sbp_msg)
        report = "Received a command response message with code {0}".format(
            response.code)
        print(report)

    def _paused_button_fired(self):
        self.console_output.paused = not self.console_output.paused

    def _log_level_filter_changed(self):
        """
        Translate the selected log level name into its mapped integer and
        store it as the console output's current filter value.
        """
        level = str_to_log_level(self.log_level_filter)
        self.console_output.log_level_filter = level

    def _clear_button_fired(self):
        self.console_output.clear()

    def _directory_name_changed(self):
        if self.baseline_view and self.solution_view:
            self.baseline_view.directory_name_b = self.directory_name
            self.solution_view.directory_name_p = self.directory_name
            self.solution_view.directory_name_v = self.directory_name
        if self.observation_view and self.observation_view_base:
            self.observation_view.dirname = self.directory_name
            self.observation_view_base.dirname = self.directory_name

    def check_heartbeat(self):
        """Update ``solid_connection`` from the heartbeat counter.

        If the heartbeat count has not changed since the last timer interval
        the connection must have dropped.
        """
        current = self.heartbeat_count
        self.solid_connection = current != self.last_timer_heartbeat
        self.last_timer_heartbeat = current

    def update_on_heartbeat(self, sbp_msg, **metadata):
        """Refresh the status-bar fields each time a heartbeat arrives.

        Increments ``heartbeat_count`` and then updates ``mode`` and
        ``num_sats`` from the most recent solution: the baseline view is
        preferred when it was updated within the last 10 seconds, otherwise
        the solution view is used under the same condition. The surveyed
        position fields of the settings view and the ``latency`` string are
        refreshed as well when the corresponding views are available.

        Parameters
        ----------
        sbp_msg : object
          The heartbeat message; its contents are not inspected here.
        """
        self.heartbeat_count += 1
        # Start from the "nothing known" state; it is kept unless a recent
        # solution is found below.
        temp_mode = "None"
        temp_num_sats = 0
        view = None
        if self.baseline_view and self.solution_view:
            # Sample the clock once so both freshness checks agree.
            now = time.time()
            # If we have a recent baseline update, we use the baseline info
            if now - self.baseline_view.last_btime_update < 10:
                view = self.baseline_view
            # Otherwise, if we have a recent SPP update, we use the SPP
            elif now - self.solution_view.last_stime_update < 10:
                view = self.solution_view
            if view and view.last_soln:
                # if all is well we update state
                temp_mode = mode_dict.get(get_mode(view.last_soln),
                                          EMPTY_STR)
                temp_num_sats = view.last_soln.n_sats

        self.mode = temp_mode
        self.num_sats = temp_num_sats

        # for auto populating surveyed fields; the solution view is the data
        # source, so it must be present too.
        if self.settings_view and self.solution_view:
            self.settings_view.lat = self.solution_view.latitude
            self.settings_view.lon = self.solution_view.longitude
            self.settings_view.alt = self.solution_view.altitude
        if self.system_monitor_view:
            latency_ms = self.system_monitor_view.msg_obs_window_latency_ms
            if latency_ms != -1:
                self.latency = "{0} ms".format(latency_ms)
            else:
                self.latency = EMPTY_STR

    def _csv_logging_button_action(self):
        if self.csv_logging and self.baseline_view.logging_b and self.solution_view.logging_p and self.solution_view.logging_v:
            print("Stopped CSV logging")
            self.csv_logging = False
            self.baseline_view.logging_b = False
            self.solution_view.logging_p = False
            self.solution_view.logging_v = False

        else:
            print("Started CSV logging at %s" % self.directory_name)
            self.csv_logging = True
            self.baseline_view.logging_b = True
            self.solution_view.logging_p = True
            self.solution_view.logging_v = True

    def _start_json_logging(self, override_filename=None):
        """Create the JSON logger and start forwarding link messages to it.

        When ``override_filename`` is not given, a timestamped file name
        inside ``directory_name`` is used. If a settings view exists, its
        settings read button handler is invoked after forwarding has started.
        """
        filename = override_filename
        if not filename:
            stamped = time.strftime("swift-gnss-%Y%m%d-%H%M%S.sbp.json")
            filename = os.path.normpath(
                os.path.join(self.directory_name, stamped))

        self.logger = s.get_logger(True, filename, self.expand_json)
        self.forwarder = sbpc.Forwarder(self.link, self.logger)
        self.forwarder.start()

        settings = self.settings_view
        if settings:
            settings._settings_read_button_fired()

    def _stop_json_logging(self):
        fwd = self.forwarder
        fwd.stop()
        self.logger.flush()
        self.logger.close()

    def _json_logging_button_action(self):
        if self.first_json_press and self.json_logging:
            print(
                "JSON Logging initiated via CMD line.  Please press button again to stop logging"
            )
        elif self.json_logging:
            self._stop_json_logging()
            self.json_logging = False
            print("Stopped JSON logging")
        else:
            self._start_json_logging()
            self.json_logging = True
        self.first_json_press = False

    def _json_logging_button_fired(self):
        if not os.path.exists(self.directory_name) and not self.json_logging:
            confirm_prompt = CallbackPrompt(
                title="Logging directory creation",
                actions=[ok_button],
                callback=self._json_logging_button_action)
            confirm_prompt.text = "\nThe selected logging directory does not exist and will be created."
            confirm_prompt.run(block=False)
        else:
            self._json_logging_button_action()

    def _csv_logging_button_fired(self):
        if not os.path.exists(self.directory_name) and not self.csv_logging:
            confirm_prompt = CallbackPrompt(
                title="Logging directory creation",
                actions=[ok_button],
                callback=self._csv_logging_button_action)
            confirm_prompt.text = "\nThe selected logging directory does not exist and will be created."
            confirm_prompt.run(block=False)
        else:
            self._csv_logging_button_action()

    def __enter__(self):
        """Enter the context manager; returns this console itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the context manager by closing the console output.

        The exception arguments are not inspected and nothing is returned, so
        an exception raised inside the ``with`` body is not suppressed.
        """
        self.console_output.close()

    def __init__(self,
                 link,
                 update,
                 log_level_filter,
                 skip_settings=False,
                 error=False,
                 cnx_desc=None,
                 json_logging=False,
                 log_dirname=None,
                 override_filename=None,
                 log_console=False,
                 networking=None,
                 connection_info=None,
                 expand_json=False):
        """Build the console: redirect output, register callbacks, make views.

        link : object
          Message handler; callbacks are registered on it and it is handed to
          every view.
        update : bool
          Passed to the update view as its ``prompt`` argument.
        log_level_filter : str
          Initial log level name; also converted with ``str_to_log_level``
          for the console output.
        skip_settings : bool
          Passed to the settings view as its ``skip`` argument.
        error : bool
          When true, ``sys.stderr`` is left untouched and a failure while
          building the views exits the process with status 1.
        cnx_desc : str
          Connection description, stored as ``cnx_desc`` and ``dev_id``.
        json_logging : bool
          Start JSON logging immediately.
        log_dirname : str
          Logging directory; ``swift_path`` is used when not given.
        override_filename : str
          JSON log file name; joined onto ``log_dirname`` when that is given.
        log_console : bool
          Passed to the console output list as its ``tfile`` argument.
        networking : str
          YAML text whose mapping is passed as keyword arguments to the
          networking view.
        connection_info : object
          Passed through to the update view.
        expand_json : bool
          Passed to the JSON logger factory.
        """
        # ``cnx_desc`` and ``dev_id`` are Str traits, which reject None, so
        # fall back to an empty string when no description was supplied.
        cnx_desc = cnx_desc or ''
        self.error = error
        self.cnx_desc = cnx_desc
        self.connection_info = connection_info
        self.dev_id = cnx_desc
        self.num_sats = 0
        self.mode = ''
        self.forwarder = None
        self.latency = '--'
        self.expand_json = expand_json

        # if we have passed a logfile, we set our directory to it
        if log_dirname:
            self.directory_name = log_dirname
            if override_filename:
                override_filename = os.path.join(log_dirname,
                                                 override_filename)
        else:
            self.directory_name = swift_path

        # Start swallowing sys.stdout and sys.stderr
        self.console_output = OutputList(tfile=log_console,
                                         outdir=self.directory_name)
        sys.stdout = self.console_output
        self.console_output.write("Console: " + CONSOLE_VERSION +
                                  " starting...")
        if not error:
            sys.stderr = self.console_output

        self.log_level_filter = log_level_filter
        self.console_output.log_level_filter = str_to_log_level(
            log_level_filter)
        try:
            self.link = link
            self.link.add_callback(self.print_message_callback,
                                   SBP_MSG_PRINT_DEP)
            self.link.add_callback(self.log_message_callback, SBP_MSG_LOG)
            self.link.add_callback(self.ext_event_callback, SBP_MSG_EXT_EVENT)
            self.link.add_callback(self.cmd_resp_callback,
                                   SBP_MSG_COMMAND_RESP)
            self.link.add_callback(self.update_on_heartbeat, SBP_MSG_HEARTBEAT)
            self.dep_handler = DeprecatedMessageHandler(link)
            settings_read_finished_functions = []
            self.tracking_view = TrackingView(self.link)
            self.solution_view = SolutionView(self.link,
                                              dirname=self.directory_name)
            self.baseline_view = BaselineView(self.link,
                                              dirname=self.directory_name)
            self.observation_view = ObservationView(
                self.link,
                name='Local',
                relay=False,
                dirname=self.directory_name)
            self.observation_view_base = ObservationView(
                self.link,
                name='Remote',
                relay=True,
                dirname=self.directory_name)
            self.system_monitor_view = SystemMonitorView(self.link)
            self.update_view = UpdateView(self.link,
                                          download_dir=swift_path,
                                          prompt=update,
                                          connection_info=self.connection_info)
            self.imu_view = IMUView(self.link)
            self.mag_view = MagView(self.link)
            self.spectrum_analyzer_view = SpectrumAnalyzerView(self.link)
            settings_read_finished_functions.append(
                self.update_view.compare_versions)
            if networking:
                from ruamel.yaml import YAML
                # The exception class lives in the error module; it is not an
                # attribute of a YAML() instance.
                from ruamel.yaml.error import YAMLError
                try:
                    networking_dict = YAML(typ='safe').load(networking)
                    # Only a mapping can be expanded into keyword arguments.
                    if not isinstance(networking_dict, dict):
                        raise YAMLError(
                            "networking argument is not a mapping")
                except YAMLError:
                    print(
                        "Unable to interpret networking cmdline argument.  It will be ignored."
                    )
                    import traceback
                    print(traceback.format_exc())
                    networking_dict = {}
                networking_dict['show_networking'] = True
            else:
                networking_dict = {}
            networking_dict.update(
                {'whitelist': [SBP_MSG_POS_LLH, SBP_MSG_HEARTBEAT]})
            self.networking_view = SbpRelayView(self.link, **networking_dict)
            self.skylark_view = SkylarkView()
            self.json_logging = json_logging
            self.csv_logging = False
            self.first_json_press = True
            if json_logging:
                self._start_json_logging(override_filename)
                self.json_logging = True
            # we set timer interval to 1200 milliseconds because we expect a heartbeat each second
            self.timer_cancel = call_repeatedly(1.2, self.check_heartbeat)

            # Once we have received the settings, update device_serial with
            # the Swift serial number which will be displayed in the window
            # title. This callback will also update the header route as used
            # by the networking view.

            def update_serial():
                mfg_id = None
                try:
                    self.uuid = self.settings_view.settings['system_info'][
                        'uuid'].value
                    mfg_id = self.settings_view.settings['system_info'][
                        'serial_number'].value
                except KeyError:
                    pass
                if mfg_id:
                    self.device_serial = 'PK' + str(mfg_id)
                self.skylark_view.set_uuid(self.uuid)
                self.networking_view.set_route(uuid=self.uuid,
                                               serial_id=mfg_id)
                if self.networking_view.connect_when_uuid_received:
                    self.networking_view._connect_rover_fired()

            settings_read_finished_functions.append(update_serial)
            self.settings_view = SettingsView(self.link,
                                              settings_read_finished_functions,
                                              skip=skip_settings)
            self.update_view.settings = self.settings_view.settings
            self.python_console_env = {
                'send_message': self.link,
                'link': self.link,
            }
            # Merge each view's console commands; later views override earlier
            # ones on a name clash, so the order is significant.
            for view in (self.tracking_view,
                         self.solution_view,
                         self.baseline_view,
                         self.observation_view,
                         self.networking_view,
                         self.system_monitor_view,
                         self.update_view,
                         self.imu_view,
                         self.mag_view,
                         self.settings_view,
                         self.spectrum_analyzer_view):
                self.python_console_env.update(view.python_console_cmds)

        except Exception:
            # Building the views is best-effort: report the failure and carry
            # on unless the caller asked for errors to be fatal. Interpreter
            # exits and keyboard interrupts are deliberately not caught.
            import traceback
            traceback.print_exc()
            if self.error:
                sys.exit(1)
class RangeSelectionOverlay(AbstractOverlay):
    """ Highlights the selection region on a component.

    Looks at a given metadata field of self.component for regions to draw as
    selected.
    """

    #: The axis to which this tool is perpendicular.
    axis = Enum("index", "value")

    #: Mapping from screen space to data space. By default, it is just
    #: self.component.
    plot = Property(depends_on='component')

    #: The mapper (and associated range) that drive this RangeSelectionOverlay.
    #: By default, this is the mapper on self.plot that corresponds to self.axis.
    mapper = Instance(AbstractMapper)

    #: The element of an (x,y) tuple that corresponds to the axis index.
    #: By default, this is set based on self.axis and self.plot.orientation,
    #: but it can be overridden and set to 0 or 1.
    axis_index = Property

    #: The name of the metadata to look at for dataspace bounds. The metadata
    #: can be either a tuple (dataspace_start, dataspace_end) in "selections" or
    #: a boolean array mask of selected dataspace points with any other name.
    metadata_name = Str("selections")

    #------------------------------------------------------------------------
    # Appearance traits
    #------------------------------------------------------------------------

    #: The color of the selection border line.
    border_color = ColorTrait("dodgerblue")
    #: The width, in pixels, of the selection border line.
    border_width = Float(1.0)
    #: The line style of the selection border line.
    border_style = LineStyle("solid")
    #: The color to fill the selection region.
    fill_color = ColorTrait("lightskyblue")
    #: The transparency of the fill color.
    alpha = Float(0.3)

    #------------------------------------------------------------------------
    # AbstractOverlay interface
    #------------------------------------------------------------------------

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        One filled, outlined rectangle is drawn per selected region. Along
        the selection axis it runs from the region's start to its end; along
        the other axis it takes the component's position and bounds.

        Overrides AbstractOverlay.
        """
        along = self.axis_index
        across = 1 - along

        # Draw the selection
        for start, end in self._get_selection_screencoords():
            origin = [0, 0]
            extent = [0, 0]
            origin[along] = start
            origin[across] = component.position[across]
            extent[along] = end - start
            extent[across] = component.bounds[across]

            with gc:
                gc.clip_to_rect(component.x, component.y, component.width,
                                component.height)
                gc.set_alpha(self.alpha)
                gc.set_fill_color(self.fill_color_)
                gc.set_stroke_color(self.border_color_)
                gc.set_line_width(self.border_width)
                gc.set_line_dash(self.border_style_)
                gc.draw_rect((origin[0], origin[1], extent[0], extent[1]))

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _get_selection_screencoords(self):
        """ Returns a tuple of (x1, x2) screen space coordinates of the start
        and end selection points.

        If there is no current selection, then returns an empty list.
        """
        ds = getattr(self.plot, self.axis)
        selection = ds.metadata.get(self.metadata_name, None)
        if selection is None:
            return []

        # "selections" metadata must be a tuple
        if self.metadata_name == "selections" or \
                (selection is not None and isinstance(selection, tuple)):
            if selection is not None and len(selection) == 2:
                return [self.mapper.map_screen(array(selection))]
            else:
                return []
        # All other metadata is interpreted as a mask on dataspace
        else:
            ar = arange(0, len(selection), 1)
            runs = arg_find_runs(ar[selection])
            coords = []
            for inds in runs:
                start = ds._data[ar[selection][inds[0]]]
                end = ds._data[ar[selection][inds[1] - 1]]
                coords.append(self.mapper.map_screen(array((start, end))))
            return coords

    def _determine_axis(self):
        """ Determines which element of an (x,y) coordinate tuple corresponds
        to the tool's axis of interest.

        This method is only called if self._axis_index hasn't been set (or is
        None).
        """
        if self.axis == "index":
            if self.plot.orientation == "h":
                return 0
            else:
                return 1
        else:  # self.axis == "value"
            if self.plot.orientation == "h":
                return 1
            else:
                return 0

    #------------------------------------------------------------------------
    # Trait event handlers
    #------------------------------------------------------------------------

    def _component_changed(self, old, new):
        self._attach_metadata_handler(old, new)
        return

    def _axis_changed(self, old, new):
        self._attach_metadata_handler(old, new)
        return

    def _attach_metadata_handler(self, old, new):
        # This is used to attach a listener to the datasource so that when
        # its metadata has been updated, we catch the event and update properly
        if not self.plot:
            return

        datasource = getattr(self.plot, self.axis)
        if old:
            datasource.on_trait_change(self._metadata_change_handler,
                                       "metadata_changed",
                                       remove=True)
        if new:
            datasource.on_trait_change(self._metadata_change_handler,
                                       "metadata_changed")
        return

    def _metadata_change_handler(self, event):
        self.component.request_redraw()
        return

    #------------------------------------------------------------------------
    # Default initializers
    #------------------------------------------------------------------------

    def _mapper_default(self):
        """ Returns the default mapper: the plot's mapper for this axis.

        If that mapper is a GridMapper, its x mapper is returned for the
        "index" axis and its y mapper otherwise.
        """
        plot_mapper = getattr(self.plot, self.axis + "_mapper")

        if not isinstance(plot_mapper, GridMapper):
            return plot_mapper

        if self.axis == 'index':
            return plot_mapper._xmapper
        return plot_mapper._ymapper

    #------------------------------------------------------------------------
    # Property getter/setters
    #------------------------------------------------------------------------

    @cached_property
    def _get_plot(self):
        """ Getter for the `plot` property; returns self.component. """
        return self.component

    @cached_property
    def _get_axis_index(self):
        """ Getter for the `axis_index` property; delegates to
        `_determine_axis()`.
        """
        return self._determine_axis()
Example #28
0
class PortChooser(HasTraits):
    """Dialog model for choosing how to connect: serial/USB or TCP/IP.

    Parameters
    ----------
    baudrate : int or None
        Preset baud rate.  When it is not one of ``BAUD_LIST`` the baud rate
        is shown read-only instead of as a drop-down.  When ``None`` (the
        default) the ``baudrate`` trait keeps its default value and the
        drop-down stays available.
    """

    # Currently selected serial device and the list it is chosen from.
    port = Str(None)
    ports = List()
    # Connection type and serial flow-control setting.
    mode = Enum(cnx_type_list)
    flow_control = Enum(flow_control_options_list)
    # TCP/IP endpoint, shown when the TCP/IP mode is selected.
    ip_port = Int(55555)
    ip_address = Str('192.168.0.222')
    # True: baud rate is editable via a drop-down; False: shown read-only.
    choose_baud = Bool(True)
    baudrate = Int()
    refresh_ports_button = SVGButton(
        label='',
        tooltip='Refresh Port List',
        filename=resource_filename(
            'console/images/fontawesome/refresh_blue.svg'),
        allow_clipping=False,
        width_padding=4,
        height_padding=4)

    traits_view = View(
        VGroup(
            Spring(height=8),
            HGroup(
                Spring(width=-2, springy=False),
                Item('mode',
                     style='custom',
                     editor=EnumEditor(values=cnx_type_list,
                                       cols=2,
                                       format_str='%s'),
                     show_label=False)),
            HGroup(VGroup(
                Label('Serial Device:'),
                HGroup(
                    Item('port',
                         editor=EnumEditor(name='ports'),
                         show_label=False,
                         springy=True),
                    Item('refresh_ports_button',
                         show_label=False,
                         padding=0,
                         height=-20,
                         width=-20),
                ),
            ),
                   VGroup(
                       Label('Baudrate:'),
                       Item('baudrate',
                            editor=EnumEditor(values=BAUD_LIST),
                            show_label=False,
                            visible_when='choose_baud'),
                       Item('baudrate',
                            show_label=False,
                            visible_when='not choose_baud',
                            style='readonly'),
                   ),
                   VGroup(
                       Label('Flow Control:'),
                       Item('flow_control',
                            editor=EnumEditor(values=flow_control_options_list,
                                              format_str='%s'),
                            show_label=False),
                   ),
                   visible_when="mode==\'Serial/USB\'"),
            HGroup(VGroup(
                Label('IP Address:'),
                Item('ip_address',
                     label="IP Address",
                     style='simple',
                     show_label=False,
                     height=-24),
            ),
                   VGroup(
                       Label('IP Port:'),
                       Item('ip_port',
                            label="IP Port",
                            style='simple',
                            show_label=False,
                            height=-24),
                   ),
                   Spring(),
                   visible_when="mode==\'TCP/IP\'"),
        ),
        buttons=['OK', 'Cancel'],
        default_button='OK',
        close_result=False,
        icon=icon,
        width=460,
        title='Swift Console v{0} - Select Interface'.format(CONSOLE_VERSION))

    def refresh_ports(self):
        """Reload ``ports`` from ``s.get_ports()``.

        Each entry returned by ``s.get_ports()`` is unpacked as a 3-tuple and
        only its first element is kept.
        """
        try:
            self.ports = [p for p, _, _ in s.get_ports()]
        except TypeError:
            # Best effort: if the listing cannot be iterated or unpacked,
            # keep the previous port list rather than failing the dialog.
            pass

    def _refresh_ports_button_fired(self):
        """Refresh the port list when the refresh button is pressed."""
        self.refresh_ports()

    def __init__(self, baudrate=None):
        """Populate the port list and apply the optional baud-rate preset."""
        # Let HasTraits run its own initialisation before traits are touched.
        super(PortChooser, self).__init__()
        self.refresh_ports()
        # As default value, use the first port in the list (if there is one).
        if self.ports:
            self.port = self.ports[0]
        if baudrate is None:
            # No preset given: assigning None to the Int trait would fail
            # validation, so keep the default and leave the chooser editable.
            return
        if baudrate not in BAUD_LIST:
            self.choose_baud = False
        self.baudrate = baudrate
class HeadViewController(HasTraits):
    """Set head views for the given coordinate system.

    Parameters
    ----------
    system : 'RAS' | 'ALS' | 'ARI'
        Coordinate system described as initials for directions associated with
        the x, y, and z axes. Relevant terms are: Anterior, Right, Left,
        Superior, Inferior.
    """

    system = Enum("RAS",
                  "ALS",
                  "ARI",
                  desc="Coordinate system: directions of "
                  "the x, y, and z axis.")

    # One button per predefined camera view.
    right = Button()
    front = Button()
    left = Button()
    top = Button()
    # Mouse interaction style applied to the scene's interactor.
    interaction = Enum('trackball', 'terrain')

    # Synced with the camera's ``parallel_scale`` once the scene is rendered.
    scale = Float(0.16)

    scene = Instance(MlabSceneModel)

    view = View(
        VGroup(VGrid('0',
                     Item('top', width=_VIEW_BUTTON_WIDTH),
                     '0',
                     Item('right', width=_VIEW_BUTTON_WIDTH),
                     Item('front', width=_VIEW_BUTTON_WIDTH),
                     Item('left', width=_VIEW_BUTTON_WIDTH),
                     columns=3,
                     show_labels=False),
               '_',
               HGroup(
                   Item('scale',
                        label='Scale',
                        editor=laggy_float_editor_headscale,
                        width=_SCALE_WIDTH,
                        show_label=True),
                   Item('interaction',
                        tooltip='Mouse interaction mode',
                        show_label=False), Spring()),
               show_labels=False))

    @on_trait_change('scene.activated')
    def _init_view(self):
        """Switch to parallel projection and wire ``scale`` to the camera."""
        scene = self.scene
        scene.parallel_projection = True

        # 'scene.activated' apparently fires several times; only wire things
        # up once a renderer exists.
        if not scene.renderer:
            return

        self.sync_trait('scale', scene.camera, 'parallel_scale')
        # A scale change does not trigger a render by default, so do it here.
        self.on_trait_change(scene.render, 'scale')
        # NOTE(review): re-assigning an unchanged value may not re-fire the
        # 'interaction' listener - confirm this re-applies the style.
        self.interaction = self.interaction  # could be delayed

    @on_trait_change('interaction')
    def on_set_interaction(self, _, interaction):
        """Install the interactor style matching ``interaction``."""
        scene = self.scene
        if scene is None or scene.interactor is None:
            return

        # Ensure we're in the correct orientation for the
        # InteractorStyleTerrain to have the correct "up"
        self.on_set_view('front', '')
        scene.mlab.draw()

        if interaction == 'terrain':
            style = tvtk.InteractorStyleTerrain()
        else:
            style = tvtk.InteractorStyleTrackballCamera()
        scene.interactor.interactor_style = style

        self.on_set_view('front', '')
        scene.mlab.draw()

    @on_trait_change('top,left,right,front')
    def on_set_view(self, view, _):
        """Point the camera at the head from the direction named by ``view``.

        Raises
        ------
        ValueError
            If ``self.system`` or ``view`` has no entry in the view table.
        """
        if self.scene is None:
            return

        # (azimuth, elevation, roll) per coordinate system and view name.
        views = {
            'ALS': {'front': (0, 90, -90),
                    'left': (90, 90, 180),
                    'right': (-90, 90, 0),
                    'top': (0, 0, -90)},
            'RAS': {'front': (90., 90., 180),
                    'left': (180, 90, 90),
                    'right': (0., 90, 270),
                    'top': (90, 0, 180)},
            'ARI': {'front': (0, 90, 90),
                    'left': (-90, 90, 180),
                    'right': (90, 90, 0),
                    'top': (0, 180, 90)},
        }
        system = self.system
        if system not in views:
            raise ValueError("Invalid system: %r" % system)
        if view not in views[system]:
            raise ValueError("Invalid view: %r" % view)

        azimuth, elevation, roll = views[system][view]
        with SilenceStdout():
            self.scene.mlab.view(azimuth=azimuth,
                                 elevation=elevation,
                                 roll=roll,
                                 focalpoint=(0., 0., 0.),
                                 distance=None,
                                 reset_roll=True,
                                 figure=self.scene.mayavi_scene)
Example #30
0
class Slider(Component):
    """ A horizontal or vertical slider bar """

    #------------------------------------------------------------------------
    # Model traits
    #------------------------------------------------------------------------

    # The value mapped to the low end of the slider bar.
    min = Float()

    # The value mapped to the high end of the slider bar.
    max = Float()

    # The current value; the slider marker is drawn at its screen position.
    value = Float()

    # The number of ticks to show on the slider.
    num_ticks = Int(4)

    #------------------------------------------------------------------------
    # Bar and endcap appearance
    #------------------------------------------------------------------------

    # Whether this is a horizontal or vertical slider
    orientation = Enum("h", "v")

    # The thickness, in pixels, of the lines used to render the ticks,
    # endcaps, and main slider bar.
    bar_width = Int(4)

    # The stroke color of the bar, endcaps, and ticks.
    bar_color = ColorTrait("black")

    # Whether or not to render endcaps on the slider bar
    endcaps = Bool(True)

    # The extent of the endcaps, in pixels.  This is a read-only property,
    # since the endcap size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    endcap_size = Property

    # The extent of the tickmarks, in pixels.  This is a read-only property,
    # since the tick size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    tick_size = Property

    #------------------------------------------------------------------------
    # Slider appearance
    #------------------------------------------------------------------------

    # The kind of marker to use for the slider.
    slider = SliderMarkerTrait("rect")

    # If the slider marker is "rect", this is the thickness of the slider,
    # i.e. its extent in the dimension parallel to the long axis of the widget.
    # For other slider markers, this has no effect.
    slider_thickness = Int(9)

    # The size of the slider, in pixels.  This is a read-only property, since
    # the slider size can be set as either a fixed number of pixels or a
    # percentage of the widget's size in the transverse direction.
    slider_size = Property

    # For slider markers with a filled area, this is the color of the filled
    # area.  For slider markers that are just lines/strokes (e.g. cross, plus),
    # this is the color of the stroke.
    slider_color = ColorTrait("red")

    # For slider markers with a filled area, this is the color of the outline
    # border drawn around the filled area.  For slider markers that have just
    # lines/strokes, this has no effect.
    slider_border = ColorTrait("none")

    # For slider markers with a filled area, this is the width, in pixels,
    # of the outline around the area.  For slider markers that are just lines/
    # strokes, this is the thickness of the stroke.
    slider_outline_width = Int(1)

    # The kiva.CompiledPath representing the custom path to render for the
    # slider, if the **slider** trait is set to "custom".
    custom_slider = Any()

    #------------------------------------------------------------------------
    # Interaction traits
    #------------------------------------------------------------------------

    # Can this slider be interacted with, or is it just a display
    interactive = Bool(True)

    # Which mouse button presses and releases the slider.
    mouse_button = Enum("left", "right")

    # Interaction state: "dragging" while the slider is being moved.
    event_state = Enum("normal", "dragging")

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Returns the coordinate index (0 or 1) corresponding to our orientation.
    # Used internally; read-only property.
    axis_ndx = Property()

    # Slider size: sizing mode, fraction used in "percent" mode, and the
    # resulting size in pixels.
    _slider_size_mode = Enum("fixed", "percent")
    _slider_percent = Float(0.0)
    _cached_slider_size = Int(10)

    # Endcap size: sizing mode, fraction used in "percent" mode, and the
    # resulting size in pixels.
    _endcap_size_mode = Enum("fixed", "percent")
    _endcap_percent = Float(0.0)
    _cached_endcap_size = Int(20)

    # Tick size: sizing mode, fraction for "percent" mode, and the size in
    # pixels used when drawing the ticks.
    _tick_size_mode = Enum("fixed", "percent")
    _tick_size_percent = Float(0.0)
    _cached_tick_size = Int(20)

    # A tuple of (dx, dy) of the difference between the mouse position and
    # center of the slider.
    _offset = Any((0, 0))

    def set_range(self, min, max):
        """ Sets both ends of the slider's value range.

        Parameters
        ==========
        min, max : Float
            The new values for the **min** and **max** traits
        """
        self.min = min
        self.max = max

    def map_screen(self, val):
        """ Returns an (x,y) coordinate corresponding to the location of
        **val** on the slider.

        Parameters
        ==========
        val : Float
            The data value to locate on the slider

        Returns
        =======
        coord : list
            A two-element [x, y] list.  Along the slider's long axis the
            value is clamped to the ends of the bar; when **min** equals
            **max** the middle of the bar is returned.  The other coordinate
            is always the middle of the widget.
        """
        # Some local variables to handle orientation dependence
        axis_ndx = self.axis_ndx
        other_ndx = 1 - axis_ndx
        extent = self.bounds[axis_ndx]
        screen_low = self.position[axis_ndx]
        screen_high = screen_low + extent

        # The return coordinate.  The return value along the non-primary
        # axis will be the same in all cases.
        coord = [0, 0]
        coord[other_ndx] = (self.position[other_ndx] +
                            self.bounds[other_ndx] / 2)

        if self.min == self.max:
            # Degenerate range.  This has to be tested first: every value
            # satisfies one of the two comparisons below when min == max, so
            # testing it last would never reach it (and the normal case
            # would divide by zero).
            coord[axis_ndx] = (screen_low + screen_high) / 2
        elif val <= self.min:
            coord[axis_ndx] = screen_low
        elif val >= self.max:
            coord[axis_ndx] = screen_high
        else:
            # Normal case: linear interpolation along the bar
            coord[axis_ndx] = ((val - self.min) / (self.max - self.min) *
                               extent + screen_low)
        return coord

    def map_data(self, x, y, clip=True):
        """ Returns the data value that corresponds to the given x and y
        screen coordinates.

        Parameters
        ==========
        x, y : Float
            The screen coordinates to map; only the one along the slider's
            long axis is used
        clip : Bool (default=True)
            Whether points outside the range should be clipped to the max
            or min value of the slider (depending on which it's closer to).
            If False, points beyond the ends of the bar are extrapolated
            linearly.

        Returns
        =======
        value : Float
            The mapped value.  If the bar has zero length on screen, the
            midpoint of min and max is returned.
        """
        # Some local variables to handle orientation dependence
        axis_ndx = self.axis_ndx
        extent = self.bounds[axis_ndx]
        screen_low = self.position[axis_ndx]
        screen_high = screen_low + extent
        coord = x if self.orientation == "h" else y

        # A zero-length bar cannot be interpolated.  This has to be tested
        # before the range checks, which otherwise match every coordinate.
        if screen_high == screen_low:
            return (self.max + self.min) / 2

        if clip:
            if coord >= screen_high:
                return self.max
            if coord <= screen_low:
                return self.min

        # Linear interpolation (or extrapolation when not clipping)
        return (coord - screen_low) / extent * (self.max - self.min) + self.min

    def set_slider_pixels(self, pixels):
        """ Sets the size of the slider to be a fixed number of pixels.

        Switches the slider to "fixed" sizing mode and stores the size.

        Parameters
        ==========
        pixels : int
            The size, in pixels, that the slider should have
        """
        self._slider_size_mode = "fixed"
        self._cached_slider_size = pixels

    def set_slider_percent(self, percent):
        """ Sets the size of the slider to be a percentage of the slider
        widget's size in the transverse direction (height for a horizontal
        slider, width for a vertical one).

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._slider_size_mode = "percent"
        self._slider_percent = percent
        # Recompute the cached pixel size right away
        self._update_sizes()

    def set_endcap_pixels(self, pixels):
        """ Sets the size of the endcaps to be a fixed number of pixels.

        Switches the endcaps to "fixed" sizing mode and stores the size.

        Parameters
        ==========
        pixels : int
            The size, in pixels, that the endcaps should have
        """
        self._endcap_size_mode = "fixed"
        self._cached_endcap_size = pixels

    def set_endcap_percent(self, percent):
        """ Sets the size of the endcaps to be a percentage of the slider
        widget's size in the transverse direction (height for a horizontal
        slider, width for a vertical one).

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._endcap_size_mode = "percent"
        self._endcap_percent = percent
        # Recompute the cached pixel size right away
        self._update_sizes()

    def set_tick_pixels(self, pixels):
        """ Sets the size of the tick marks to be a fixed number of pixels.

        Switches the tick marks to "fixed" sizing mode and stores the size.

        Parameters
        ==========
        pixels : int
            The size, in pixels, that the tick marks should have
        """
        self._tick_size_mode = "fixed"
        self._cached_tick_size = pixels

    def set_tick_percent(self, percent):
        """ Sets the size of the tick marks to be a percentage of the slider
        widget's size in the transverse direction.

        Switches the tick marks to "percent" sizing mode, stores the
        percentage and calls _update_sizes().

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._tick_size_mode = "percent"
        # Store on the declared `_tick_size_percent` trait.  The name used
        # before, `_tick_percent`, matches no declared trait and is not read
        # anywhere in this class.
        self._tick_size_percent = percent
        self._update_sizes()

    #------------------------------------------------------------------------
    # Rendering methods
    #------------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="normal"):
        """ Renders the bar, the optional endcaps and ticks, and the slider
        marker onto **gc**.

        **view_bounds** and **mode** are accepted but not used here.
        """
        center_x = self.x + self.width / 2
        center_y = self.y + self.height / 2

        # Draw the bar and endcaps
        gc.set_stroke_color(self.bar_color_)
        gc.set_line_width(self.bar_width)
        if self.orientation == "h":
            gc.move_to(self.x, center_y)
            gc.line_to(self.x2, center_y)
            gc.stroke_path()
            if self.endcaps:
                half_cap = self._cached_endcap_size / 2
                cap_low = center_y - half_cap
                cap_high = center_y + half_cap
                gc.move_to(self.x, cap_low)
                gc.line_to(self.x, cap_high)
                gc.move_to(self.x2, cap_low)
                gc.line_to(self.x2, cap_high)
            if self.num_ticks > 0:
                # num_ticks interior ticks plus one at each end of the bar
                half_tick = self._cached_tick_size / 2
                tick_x = linspace(self.x, self.x2,
                                  self.num_ticks + 2).astype(int)
                starts = zeros((len(tick_x), 2), dtype=int)
                starts[:, 0] = tick_x
                starts[:, 1] = center_y - half_tick
                ends = starts.copy()
                ends[:, 1] = center_y + half_tick
                gc.line_set(starts, ends)
        else:
            gc.move_to(center_x, self.y)
            gc.line_to(center_x, self.y2)
            if self.endcaps:
                half_cap = self._cached_endcap_size / 2
                cap_low = center_x - half_cap
                cap_high = center_x + half_cap
                gc.move_to(cap_low, self.y)
                gc.line_to(cap_high, self.y)
                gc.move_to(cap_low, self.y2)
                gc.line_to(cap_high, self.y2)
            if self.num_ticks > 0:
                # num_ticks interior ticks plus one at each end of the bar
                half_tick = self._cached_tick_size / 2
                tick_y = linspace(self.y, self.y2,
                                  self.num_ticks + 2).astype(int)
                starts = zeros((len(tick_y), 2), dtype=int)
                starts[:, 1] = tick_y
                starts[:, 0] = center_x - half_tick
                ends = starts.copy()
                ends[:, 0] = center_x + half_tick
                gc.line_set(starts, ends)
        gc.stroke_path()

        # Draw the slider
        marker_pt = self.map_screen(self.value)
        if self.slider == "rect":
            gc.set_fill_color(self.slider_color_)
            gc.set_stroke_color(self.slider_border_)
            gc.set_line_width(self.slider_outline_width)
            gc.rect(*self._get_rect_slider_bounds())
            gc.draw_path()
        else:
            self._render_marker(gc, marker_pt, self._cached_slider_size,
                                self.slider_(), self.custom_slider)

    def _get_rect_slider_bounds(self):
        """ Returns the (x, y, w, h) bounds of the "rect" slider marker.

        Used for both rendering and hit detection.
        """
        # NOTE(review): along the long axis the rectangle ends at the mapped
        # point rather than being centered on it - confirm this is intended.
        marker_pt = self.map_screen(self.value)
        thickness = self.slider_thickness
        size = self._cached_slider_size
        if self.orientation == "h":
            center_y = self.y + self.height / 2
            return (marker_pt[0] - thickness, center_y - size / 2,
                    thickness, size)
        center_x = self.x + self.width / 2
        return (center_x - size / 2, marker_pt[1] - thickness,
                size, thickness)

    def _render_marker(self, gc, point, size, marker, custom_path):
        """ Draws **marker** at **point** on **gc**, using the most direct
        drawing call the graphics context offers.

        Parameters
        ==========
        gc : graphics context
            The context to draw on; its state is saved and restored around
            the drawing
        point : sequence of 2 numbers
            The screen position at which to draw the marker
        size : number
            The marker size passed on to the marker/gc drawing calls
        marker : marker instance
            Provides draw_mode, antialias, kiva_marker and add_to_path()
        custom_path : path
            The path drawn instead when **marker** is a CustomMarker
        """
        with gc:
            gc.begin_path()
            # Stroke-only markers use the slider color for the stroke; filled
            # markers use it as fill and draw the border separately.
            if marker.draw_mode == STROKE:
                gc.set_stroke_color(self.slider_color_)
                gc.set_line_width(self.slider_thickness)
            else:
                gc.set_fill_color(self.slider_color_)
                gc.set_stroke_color(self.slider_border_)
                gc.set_line_width(self.slider_outline_width)

            # First choice: the gc's own marker drawing.  The call sits in
            # the condition itself, so it does the drawing; a return value
            # of 0 falls through to the path-based branches below.
            if hasattr(gc, "draw_marker_at_points") and \
                    (marker.__class__ != CustomMarker) and \
                    (gc.draw_marker_at_points([point], size, marker.kiva_marker) != 0):
                pass
            # Second choice: build a path once and stamp it at the point.
            elif hasattr(gc, "draw_path_at_points"):
                if marker.__class__ != CustomMarker:
                    path = gc.get_empty_path()
                    marker.add_to_path(path, size)
                    mode = marker.draw_mode
                else:
                    path = custom_path
                    mode = STROKE
                if not marker.antialias:
                    gc.set_antialias(False)
                gc.draw_path_at_points([point], path, mode)
            # Last resort: translate to the point and draw the path directly.
            else:
                if not marker.antialias:
                    gc.set_antialias(False)
                if marker.__class__ != CustomMarker:
                    gc.translate_ctm(*point)
                    # Kiva GCs have a path-drawing interface
                    marker.add_to_path(gc, size)
                    gc.draw_path(marker.draw_mode)
                else:
                    path = custom_path
                    gc.translate_ctm(*point)
                    gc.add_path(path)
                    gc.draw_path(STROKE)

    #------------------------------------------------------------------------
    # Interaction event handlers
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        """ Forwards the press to _mouse_pressed() if **mouse_button** is
        "left"; otherwise does nothing.
        """
        if self.mouse_button == "left":
            return self._mouse_pressed(event)

    def dragging_left_up(self, event):
        """ Forwards the release to _mouse_released() if **mouse_button** is
        "left"; otherwise does nothing.
        """
        if self.mouse_button == "left":
            return self._mouse_released(event)

    def normal_right_down(self, event):
        """ Forwards the press to _mouse_pressed() if **mouse_button** is
        "right"; otherwise does nothing.
        """
        if self.mouse_button == "right":
            return self._mouse_pressed(event)

    def dragging_right_up(self, event):
        """ Forwards the release to _mouse_released() if **mouse_button** is
        "right"; otherwise does nothing.
        """
        if self.mouse_button == "right":
            return self._mouse_released(event)

    def dragging_mouse_move(self, event):
        """ Moves the slider to follow the mouse.

        The stored **_offset** is subtracted from the event position before
        it is mapped to a value, then a redraw is requested.
        """
        offset_x, offset_y = self._offset
        target_x = event.x - offset_x
        target_y = event.y - offset_y
        self.value = self.map_data(target_x, target_y)
        event.handled = True
        self.request_redraw()

    def dragging_mouse_leave(self, event):
        """ Resets **event_state** to "normal", ending the drag. """
        self.event_state = "normal"

    def _mouse_pressed(self, event):
        """ Starts a drag if the press landed on the slider or on the bar.

        A press on the slider itself records the offset between the mouse
        and the slider position.  A press on the bar but outside the slider
        first moves the slider to the press position.  A press that misses
        both is ignored and the event is left unhandled.
        """
        marker_pt = self.map_screen(self.value)

        # Bounding box of the slider marker, used for hit testing
        if self.slider == "rect":
            left, bottom, box_w, box_h = self._get_rect_slider_bounds()
            right = left + box_w
            top = bottom + box_h
        else:
            left, bottom = marker_pt
            size = self._cached_slider_size
            left -= size / 2
            bottom -= size / 2
            right = left + size
            top = bottom + size

        within_x = left <= event.x <= right
        within_y = bottom <= event.y <= top

        # "on_bar" is the test across the bar, "on_slider" the test along it.
        # The jump offset keeps only the transverse component, so the slider
        # lands exactly on the press position along the bar.
        if self.orientation == "v":
            on_bar, on_slider = within_x, within_y
            jump_offset = (event.x - marker_pt[0], 0)
        elif self.orientation == "h":
            on_bar, on_slider = within_y, within_x
            jump_offset = (0, event.y - marker_pt[1])
        else:
            return

        if not on_bar:
            # The mouse click missed the bar and the slider.
            return

        if on_slider:
            self._offset = (event.x - marker_pt[0], event.y - marker_pt[1])
        else:
            self._offset = jump_offset
            self.dragging_mouse_move(event)

        event.handled = True
        self.event_state = "dragging"

    def _mouse_released(self, event):
        """ Ends the drag: resets **event_state** to "normal" and marks the
        event as handled.
        """
        self.event_state = "normal"
        event.handled = True

    #------------------------------------------------------------------------
    # Private trait event handlers and property getters/setters
    #------------------------------------------------------------------------

    def _get_axis_ndx(self):
        """ Property getter: 0 for a horizontal slider, 1 otherwise. """
        return 0 if self.orientation == "h" else 1

    def _get_slider_size(self):
        """ Property getter for **slider_size**: the cached size in pixels. """
        return self._cached_slider_size

    def _get_endcap_size(self):
        """ Property getter for **endcap_size**: the cached size in pixels. """
        return self._cached_endcap_size

    def _get_tick_size(self):
        """ Property getter for **tick_size**: the cached size in pixels. """
        return self._cached_tick_size

    @on_trait_change("bounds,bounds_items")
    def _update_sizes(self):
        """ Recomputes the cached pixel sizes of every element that is in
        "percent" sizing mode.

        Runs whenever the widget bounds change.  Sizes are a fraction of the
        widget's transverse extent: the height for a horizontal slider, the
        width for a vertical one.  Elements in "fixed" mode are left alone.
        """
        if self.orientation == "h":
            transverse = self.height
        else:
            transverse = self.width

        if self._slider_size_mode == "percent":
            self._cached_slider_size = int(transverse * self._slider_percent)
        if self._endcap_size_mode == "percent":
            self._cached_endcap_size = int(transverse * self._endcap_percent)
        # Ticks were previously skipped here, so a tick size in "percent"
        # mode never took effect; the fraction is read from the declared
        # `_tick_size_percent` trait.
        if self._tick_size_mode == "percent":
            self._cached_tick_size = int(transverse * self._tick_size_percent)