コード例 #1
0
ファイル: view_pager.py プロジェクト: codelv/enaml-native
class ViewPager(ViewGroup):
    """Layout manager that allows the user to flip left and right through
    pages of data.

    """

    #: Set the currently selected page.
    current_index = d_(Int())

    #: Set the number of pages that should be retained to either side
    #: of the current page in the view hierarchy in an idle state.
    offscreen_page_limit = d_(Int())

    #: Enable or disable paging by swiping
    paging_enabled = d_(Bool(True))

    #: Set the margin between pages.
    #: Defaults to -1 (presumably "use the toolkit default" — confirm).
    page_margin = d_(Int(-1))

    #: Read only list of pages
    pages = Property()

    def _get_pages(self):
        """Getter for `pages`: the Fragment children of this widget."""
        return [c for c in self._children if isinstance(c, Fragment)]

    #: Transition style used when flipping between pages.
    transition = d_(
        Enum(
            "default",
            "accordion",
            "bg_to_fg",
            "fg_to_bg",
            "cube_in",
            "cube_out",
            "draw_from_back",
            "flip_horizontal",
            "flip_vertical",
            "depth_page",
            "parallax_page",
            "rotate_down",
            "rotate_up",
            "stack",
            "tablet",
            "zoom_in",
            "zoom_out",
            "zoom_out_slide",
        )
    )

    #: A reference to the ProxyViewPager object.
    proxy = Typed(ProxyViewPager)

    # -------------------------------------------------------------------------
    # Observers
    # -------------------------------------------------------------------------
    @observe(
        "current_index",
        "offscreen_page_limit",
        "page_margin",
        "paging_enabled",
        "transition",
    )
    def _update_proxy(self, change):
        """An observer which forwards state changes to the proxy."""

        super()._update_proxy(change)
コード例 #2
0
ファイル: qt_field.py プロジェクト: zxyfanta/enaml
class QtField(QtControl, ProxyField):
    """ A Qt4 implementation of an Enaml ProxyField.

    Wraps a QFocusLineEdit and keeps its text, validation state and
    submit behavior in sync with the Enaml declaration.

    """
    #: A reference to the widget created by the proxy.
    widget = Typed(QFocusLineEdit)

    #: A collapsing timer for auto sync text.
    _text_timer = Typed(QTimer)

    #: Cyclic notification guard. This is a bitfield of multiple guards
    #: (TEXT_GUARD and ERROR_FLAG, defined elsewhere in this module).
    _guard = Int(0)

    #--------------------------------------------------------------------------
    # Initialization API
    #--------------------------------------------------------------------------
    def create_widget(self):
        """ Creates the underlying QFocusLineEdit widget.

        """
        self.widget = QFocusLineEdit(self.parent_widget())

    def init_widget(self):
        """ Create and initialize the underlying widget.

        Applies all declaration attributes to the widget and connects
        the 'textEdited' signal for live validation.

        """
        super(QtField, self).init_widget()
        d = self.declaration
        if d.text:
            self.set_text(d.text)
        if d.mask:
            self.set_mask(d.mask)
        if d.placeholder:
            self.set_placeholder(d.placeholder)
        self.set_echo_mode(d.echo_mode)
        self.set_max_length(d.max_length)
        self.set_read_only(d.read_only)
        self.set_submit_triggers(d.submit_triggers)
        self.set_text_align(d.text_align)
        self.widget.textEdited.connect(self.on_text_edited)

    #--------------------------------------------------------------------------
    # Private API
    #--------------------------------------------------------------------------
    def _validate_and_apply(self):
        """ Validate and apply the text in the control.

        Runs the declaration's validator (giving it one fixup attempt)
        and, on success, clears the error state and writes the text to
        the declaration.

        """
        d = self.declaration
        v = d.validator

        text = self.widget.text()
        if v and not v.validate(text):
            # Give the validator one chance to repair the text.
            text = v.fixup(text)
            if not v.validate(text):
                return

        if text != self.widget.text():
            # The fixup changed the text, so reflect it in the widget.
            self.widget.setText(text)

        self._clear_error_state()
        d.text = text

    def _set_error_state(self):
        """ Set the error state of the widget.

        Applies a red-border stylesheet and shows the validator message
        as the tooltip. Guarded by ERROR_FLAG to avoid re-applying.

        """
        # A temporary hack until styles are implemented
        if not self._guard & ERROR_FLAG:
            self._guard |= ERROR_FLAG
            s = u'QLineEdit { border: 2px solid red; background: rgb(255, 220, 220); }'
            self.widget.setStyleSheet(s)
            v = self.declaration.validator
            self.widget.setToolTip(v and v.message or u'')

    def _clear_error_state(self):
        """ Clear the error state of the widget.

        Restores the declared stylesheet and clears the tooltip; no-op
        when ERROR_FLAG is not set.

        """
        # A temporary hack until styles are implemented
        if self._guard & ERROR_FLAG:
            self._guard &= ~ERROR_FLAG
            # Replace the widget's "error" stylesheet with
            # the one defined in the declaration
            self.refresh_style_sheet()
            self.widget.setToolTip(u'')

    def _maybe_valid(self, text):
        """ Get whether the text is valid or can be valid.

        Returns
        -------
        result : bool
            True if the text is valid, or can be made to be valid,
            False otherwise.

        """
        v = self.declaration.validator
        return v is None or v.validate(text) or v.validate(v.fixup(text))

    #--------------------------------------------------------------------------
    # Signal Handlers
    #--------------------------------------------------------------------------
    def on_submit_text(self):
        """ The signal handler for the text submit triggers.

        TEXT_GUARD prevents set_text() from re-entering while the
        declaration is being updated.

        """
        self._guard |= TEXT_GUARD
        try:
            self._validate_and_apply()
        finally:
            self._guard &= ~TEXT_GUARD

    def on_text_edited(self):
        """ The signal handler for the 'textEdited' signal.

        Updates the error indication and restarts the collapsing
        auto-sync timer, if one is active.

        """
        if not self._maybe_valid(self.widget.text()):
            self._set_error_state()
        else:
            self._clear_error_state()

        if self._text_timer is not None:
            self._text_timer.start()

    #--------------------------------------------------------------------------
    # ProxyField API
    #--------------------------------------------------------------------------
    def set_text(self, text):
        """ Set the text in the widget.

        Skipped while TEXT_GUARD is held to avoid a cyclic update.

        """
        if not self._guard & TEXT_GUARD:
            self.widget.setText(text)
            self._clear_error_state()

    def set_mask(self, mask):
        """ Set the input mask for the widget.

        """
        self.widget.setInputMask(mask)

    def set_submit_triggers(self, triggers):
        """ Set the submit triggers for the widget.

        Disconnects any previous trigger connections, then wires the
        requested ones ('lost_focus', 'return_pressed', 'auto_sync').

        """
        widget = self.widget
        handler = self.on_submit_text
        try:
            widget.lostFocus.disconnect()
        except (TypeError, RuntimeError):  # was never connected
            pass
        try:
            widget.returnPressed.disconnect()
        except (TypeError, RuntimeError):  # was never connected
            pass
        if 'lost_focus' in triggers:
            widget.lostFocus.connect(handler)
        if 'return_pressed' in triggers:
            widget.returnPressed.connect(handler)
        if 'auto_sync' in triggers:
            if self._text_timer is None:
                timer = self._text_timer = QTimer()
                timer.setSingleShot(True)
                timer.setInterval(self.declaration.sync_time)
                timer.timeout.connect(handler)
        else:
            if self._text_timer is not None:
                self._text_timer.stop()
                self._text_timer = None

    def set_sync_time(self, time):
        """ Set the sync time for the widget.

        NOTE(review): the ``time`` argument is ignored; the interval is
        re-read from ``self.declaration.sync_time``. Equivalent only if
        the declaration was updated before this call — confirm.

        """
        if self._text_timer is not None:
            self._text_timer.setInterval(self.declaration.sync_time)

    def set_placeholder(self, text):
        """ Set the placeholder text of the widget.

        """
        self.widget.setPlaceholderText(text)

    def set_echo_mode(self, mode):
        """ Set the echo mode of the widget.

        """
        self.widget.setEchoMode(ECHO_MODES[mode])

    def set_max_length(self, length):
        """ Set the maximum text length in the widget.

        Non-positive or over-limit values fall back to Qt's maximum
        of 32767 characters.

        """
        if length <= 0 or length > 32767:
            length = 32767
        self.widget.setMaxLength(length)

    def set_read_only(self, read_only):
        """ Set whether or not the widget is read only.

        """
        self.widget.setReadOnly(read_only)

    def set_text_align(self, text_align):
        """ Set the alignment for the text in the field.

        """
        qt_align = ALIGN_OPTIONS[text_align]
        self.widget.setAlignment(qt_align)

    def field_text(self):
        """ Get the text stored in the widget.

        """
        return self.widget.text()
コード例 #3
0
ファイル: models.py プロジェクト: yairvillarp/inkcut
class JobInfo(Model):
    """ Job metadata """
    #: Controls (runtime flags toggled while a job runs)
    done = Bool()
    cancelled = Bool()
    paused = Bool()

    #: Flags
    status = Enum('staged', 'waiting', 'running', 'error', 'approved',
                  'cancelled', 'complete').tag(config=True)

    #: Stats
    created = Instance(datetime).tag(config=True)
    started = Instance(datetime).tag(config=True)
    ended = Instance(datetime).tag(config=True)
    progress = Range(0, 100, 0).tag(config=True)
    data = Unicode().tag(config=True)
    #: Number of completed runs (incremented in _observe_done)
    count = Int().tag(config=True)

    #: Device speed in px/s
    speed = Float(strict=False).tag(config=True)
    #: Length in px
    length = Float(strict=False).tag(config=True)

    #: Estimates based on length and speed
    duration = Instance(timedelta, ()).tag(config=True)

    #: Units
    units = Enum('in', 'cm', 'm', 'ft').tag(config=True)

    #: Callback to open the approval dialog
    auto_approve = Bool().tag(config=True)
    request_approval = Callable()

    def __init__(self, *args, **kwargs):
        super(JobInfo, self).__init__(*args, **kwargs)
        # Set `created` eagerly so the timestamp reflects construction
        # time (Atom defaults are otherwise resolved lazily on access).
        self.created = self._default_created()

    def _default_created(self):
        """Default for `created`: the current local time."""
        return datetime.now()

    def _default_request_approval(self):
        """ Request approval using the current job """
        # Imported here to avoid a circular import at module load time
        # (presumably — confirm against the package layout).
        from inkcut.core.workbench import InkcutWorkbench
        workbench = InkcutWorkbench.instance()
        plugin = workbench.get_plugin("inkcut.job")
        return lambda: plugin.request_approval(plugin.job)

    def reset(self):
        """ Reset to initial states"""
        #: TODO: replace these parallel flags with a single state value
        self.progress = 0
        self.paused = False
        self.cancelled = False
        self.done = False
        self.status = 'staged'

    def _observe_done(self, change):
        """Static observer: count completions of this job."""
        if change['type'] == 'update':
            #: Increment count every time it's completed
            if self.done:
                self.count += 1

    @observe('length', 'speed')
    def _update_duration(self, change):
        """Recompute the estimated duration as length / speed seconds."""
        if not self.length or not self.speed:
            # Avoid division by zero; unknown inputs give zero duration.
            self.duration = timedelta()
            return
        dt = self.length / self.speed
        self.duration = timedelta(seconds=dt)
コード例 #4
0
class ASTNode(Atom):
    """ The base class for Enaml ast nodes.

    """
    #: The line number in the .enaml file which generated the node.
    #: Defaults to -1 (presumably meaning "no line information" — confirm).
    lineno = Int(-1)
コード例 #5
0
class PlotFormat(PlotUpdate):
    """Base class corresponding to one graph or collection on axes.

    Each instance registers itself under a unique `plot_name` in the
    plotter's `plot_dict` and wraps one matplotlib artist (or a dict of
    artists) in `clt`.
    """

    #: Unique name of this plot within the plotter (assigned in __init__).
    plot_name = ReadOnly()

    #: Flags controlling append/remove behavior of the collection.
    append = Bool(False)
    remove = Bool(False)

    #: Cursor coordinates and the corresponding data indices.
    xcoord = Float()
    ycoord = Float()
    xind = Int()
    yind = Int()

    def do_autolim(self):
        """Update the plotter's axis limits for this plot's data.

        When auto limits are enabled, the plotter's stored min/max are
        expanded to cover xdata/ydata (NaNs ignored via nanmin/nanmax);
        otherwise the stored limits are applied as-is. Re-draws the
        legend when it is shown.
        """
        if self.plotter.auto_xlim:
            self.plotter.x_min = float(
                min((self.plotter.x_min, nanmin(self.xdata))))
            self.plotter.x_max = float(
                max((self.plotter.x_max, nanmax(self.xdata))))
        else:
            self.plotter.set_xlim(self.plotter.x_min, self.plotter.x_max)
        if self.plotter.auto_ylim:
            self.plotter.y_min = float(
                min((self.plotter.y_min, nanmin(self.ydata))))
            self.plotter.y_max = float(
                max((self.plotter.y_max, nanmax(self.ydata))))
        else:
            self.plotter.set_ylim(self.plotter.y_min, self.plotter.y_max)
        if self.plotter.show_legend:
            self.plotter.legend()

    #: The data arrays rendered by this format.
    xdata = Array()
    ydata = Array()

    plot_type = Enum("line", "scatter", "multiline", "colormap", "vline",
                     "hline", "polygon", "cross_cursor")

    #: The backing matplotlib artist; may also hold a dict of artists
    #: (see clt_values), although it is declared as Line2D.
    clt = Typed(Line2D)

    visible = Bool(True).tag(former="visible")

    def clt_values(self):
        """Return the backing artist(s) as an iterable, dict-aware."""
        if isinstance(self.clt, dict):
            return self.clt.values()
        return [self.clt]

    def plot_set(self, param):
        """Apply attribute `param` to every backing artist via simple_set."""
        for clt in self.clt_values():
            simple_set(clt, self, get_tag(self, param, "former", param))

    @plot_observe("visible")
    def plot_update(self, change):
        """Set the clt's parameter to the obj's value using clt's set function."""
        self.plot_set(change["name"])

    def remove_collection(self):
        """Remove the backing artist(s) from the axes when `remove` is set.

        Fix: iterate clt_values() instead of calling self.clt.remove()
        directly, so a dict-valued clt (supported by clt_values and
        plot_set) no longer raises AttributeError.
        """
        if self.remove:
            for clt in self.clt_values():
                if clt is not None:
                    clt.remove()

    def __init__(self, **kwargs):
        """Create the format and register it with its plotter.

        `plot_name` may be supplied explicitly; otherwise the plot_type
        is used as the base name and made unique via name_generator.
        """
        plot_name = kwargs.pop("plot_name", self.plot_type)
        plotter = kwargs["plotter"]
        self.plot_name = name_generator(
            plot_name, plotter.plot_dict,
            kwargs.get("plot_type", self.plot_type))
        super(PlotFormat, self).__init__(**kwargs)
        # Register this format in the plotter's dictionary of plots.
        self.plotter.plot_dict[self.plot_name] = self

    @cached_property
    def view_window(self):
        """Lazily build and cache the enaml view for this plotter."""
        with imports():
            from plot_format_e import Main
        return Main(pltr=self.plotter)
コード例 #6
0
class VNA_Two_Tone_Pwr_Lyzer(VNA_Pwr_Lyzer):
    """Two-tone VNA power-sweep analyzer.

    Extends VNA_Pwr_Lyzer with a second drive tone: adds 2nd-frequency
    and 2nd-power sweep axes plus the indices selecting a slice of them.
    """
    base_name = "vna_two_tone_lyzer"

    #: Sweep values of the second tone's frequency (GHz).
    frq2 = Array().tag(unit="GHz", label="2nd frequency", sub=True)

    #: Sweep values of the second tone's power (dBm).
    pwr2 = Array().tag(unit="dBm", label="2nd power", sub=True)

    #: Index selecting the current slice along the 2nd-frequency axis.
    frq2_ind = Int()

    #: Index selecting the current slice along the 2nd-power axis.
    pwr2_ind = Int()

    #: Ordering of the sweep loops in the measurement.
    swp_type = Enum("pwr_first", "yoko_first")

    def _default_read_data(self):
        # The module-level read_data function is the default reader.
        return read_data

    @tag_property(sub=True)
    def Magcom(self):
        """Magnitude data slice for the currently selected filter type.

        Fix: the original contained unreachable statements after this
        if/elif/else (every branch returns) which also referenced an
        undefined local ``Magcom``; they were removed. Behavior is
        unchanged.
        """
        if self.filter_type == "None":
            return self.MagcomData[:, :, self.frq2_ind, self.pwr2_ind]
        elif self.filter_type == "Fit":
            return self.MagAbsFit
        return self.MagcomFilt[self.indices, :, :]

    @private_property
    def MagcomFilt(self):
        """Filtered magnitude data for the selected 2nd-tone slice.

        NOTE(review): in the FIR branch the outer loop variable ``m`` is
        unused, so the same filtered rows are repeated len(self.pwr)
        times; this looks suspicious but is preserved as-is — confirm
        against callers before changing.
        """
        if self.filt.filter_type == "FIR":
            return array([[
                self.filt.fir_filter(self.MagcomData[:, n, self.frq2_ind,
                                                     self.pwr2_ind])
                for n in self.flat_flux_indices
            ] for m in range(len(self.pwr))]).transpose()
        return squeeze(
            array([[[
                self.filt.fft_filter(self.MagcomData[:, n, m, self.pwr2_ind])
                for n in self.flat_flux_indices
            ]] for m in range(len(self.frq2))]).transpose())

    @tag_property(sub=True)
    def MagAbsFilt_sq(self):
        """Squared absolute value of the selected filtered data slice."""
        return absolute(self.MagcomFilt[:, :, self.frq2_ind, self.pwr2_ind,
                                        self.pwr_ind])**2

    @private_property
    def fit_params(self):
        """Run the full fit once (if needed) and return its parameters."""
        if self.fitter.fit_params is None:
            self.fitter.full_fit(x=self.flux_axis[self.flat_flux_indices],
                                 y=self.MagAbsFilt_sq,
                                 indices=self.flat_indices,
                                 gamma=self.fitter.gamma)
            if self.calc_p_guess:
                self.fitter.make_p_guess(
                    self.flux_axis[self.flat_flux_indices],
                    y=self.MagAbsFilt_sq,
                    indices=self.flat_indices,
                    gamma=self.fitter.gamma)
        return self.fitter.fit_params

    @private_property
    def MagAbsFit(self):
        """Magnitude reconstructed from the fitted parameters."""
        return sqrt(
            self.fitter.reconstruct_fit(self.flux_axis[self.flat_flux_indices],
                                        self.fit_params))
コード例 #7
0
ファイル: draw_image.py プロジェクト: mrinaleni/PyXRF
class DrawImageAdvanced(Atom):
    """
    This class performs 2D image rendering, such as showing multiple
    2D fitting or roi images based on user's selection.

    Attributes
    ----------
    img_data : dict
        dict of 2D array
    fig : object
        matplotlib Figure
    file_name : str
    stat_dict : dict
        determine which image to show
    data_dict : dict
        multiple data sets to plot, such as fit data, or roi data
    data_dict_keys : list
    data_opt : int
        index to show which data is chosen to plot
    dict_to_plot : dict
        selected data dict to plot, i.e., fitting data or roi is selected
    items_in_selected_group : list
        keys of dict_to_plot
    scale_opt : str
        linear or log plot
    color_opt : str
        orange or gray plot
    scaler_norm_dict : dict
        scaler normalization data, from data_dict
    scaler_items : list
        keys of scaler_norm_dict
    scaler_name_index : int
        index to select on GUI level
    scaler_data : None or numpy
        selected scaler data
    x_pos : list
        define data range in horizontal direction
    y_pos : list
        define data range in vertical direction
    pixel_or_pos : int
        index to choose plot with pixel (== 0) or with positions (== 1)
    grid_interpolate: bool
        choose to interpolate 2D image in terms of x,y or not
    limit_dict : Dict
        save low and high limit for image scaling
    """

    fig = Typed(Figure)
    stat_dict = Dict()
    data_dict = Dict()
    data_dict_keys = List()
    data_opt = Int(0)
    dict_to_plot = Dict()
    items_in_selected_group = List()
    #: Items selected previously, reused as defaults (see get_default_items).
    items_previous_selected = List()

    scale_opt = Str('Linear')
    color_opt = Str('viridis')
    #: Title of the currently selected data group (set in _update_file).
    img_title = Str()

    scaler_norm_dict = Dict()
    scaler_items = List()
    scaler_name_index = Int()
    scaler_data = Typed(object)

    x_pos = List()
    y_pos = List()

    pixel_or_pos = Int(0)
    grid_interpolate = Bool(False)
    #: Deep copy of the initially loaded data (see set_default_dict).
    data_dict_default = Dict()
    limit_dict = Dict()
    #: Per-image default (low, high) range (see set_low_high_value).
    range_dict = Dict()
    #: Whether to render a scatter plot instead of an image.
    scatter_show = Bool(False)
    #: Names of data that must not be normalized by the scaler.
    name_not_scalable = List()

    def __init__(self):
        # Create the matplotlib figure used for all 2D image rendering.
        self.fig = plt.figure(figsize=(3, 2))
        # NOTE: mutates global matplotlib state, not just this figure.
        matplotlib.rcParams['axes.formatter.useoffset'] = True

        # Do not apply scaler norm on following data
        self.name_not_scalable = [
            'r2_adjust', 'alive', 'dead', 'elapsed_time', 'scaler_alive',
            'i0_time', 'time', 'time_diff', 'dwell_time'
        ]

    def data_dict_update(self, change):
        """
        Observer function to be connected to the fileio model
        in the top-level gui.py startup

        Parameters
        ----------
        change : dict
            The change dictionary passed to a function with the
            @observe decorator; its 'value' entry holds the new data dict.
        """
        self.data_dict = change['value']

    def set_default_dict(self, data_dict):
        """Store a deep copy of *data_dict* as the pristine default data."""
        self.data_dict_default = copy.deepcopy(data_dict)

    @observe('data_dict')
    def init_plot_status(self, change):
        """Observer for data_dict: re-initialize plotting state.

        Picks up scaler normalization data, position arrays and previous
        default selections from the new data_dict, resets display
        options, then redraws the image.
        """
        # Groups whose name contains 'scaler' carry normalization data.
        scaler_groups = [
            v for v in list(self.data_dict.keys()) if 'scaler' in v
        ]
        if len(scaler_groups) > 0:
            # self.scaler_group_name = scaler_groups[0]
            self.scaler_norm_dict = self.data_dict[scaler_groups[0]]
            # for GUI purpose only
            self.scaler_items = []
            self.scaler_items = list(self.scaler_norm_dict.keys())
            self.scaler_items.sort()
            self.scaler_data = None

        # init of pos values
        self.pixel_or_pos = 0

        if 'positions' in self.data_dict:
            try:
                logger.debug(
                    f"Position keys: {list(self.data_dict['positions'].keys())}"
                )
                self.x_pos = list(self.data_dict['positions']['x_pos'][0, :])
                self.y_pos = list(self.data_dict['positions']['y_pos'][:, -1])
                # when we use imshow, the x and y start at lower left,
                # so flip y, we want y starts from top left
                self.y_pos.reverse()

            except KeyError:
                # Missing x_pos/y_pos keys: leave positions untouched.
                pass
        else:
            self.x_pos = []
            self.y_pos = []

        self.get_default_items()  # use previous defined elements as default
        logger.info('Use previously selected items as default: {}'.format(
            self.items_previous_selected))

        # initiate the plotting status once new data is coming
        self.reset_to_default()
        self.data_dict_keys = []
        self.data_dict_keys = list(self.data_dict.keys())
        logger.debug(
            'The following groups are included for 2D image display: {}'.
            format(self.data_dict_keys))

        self.show_image()

    def reset_to_default(self):
        """Set variables to default values as initiated.
        """
        # Deselect the data group and clear all per-image selections.
        self.data_opt = 0
        # init of scaler for normalization
        self.scaler_name_index = 0
        self.plot_deselect_all()

    def get_default_items(self):
        """Add previously selected items as the default selection.

        Scans every data group for each previously selected item name
        and, when any were selected, stores the matches under the
        'use_default_selection' key of data_dict.

        Fix: the original iterated ``for v, k in self.data_dict.items()``
        with the key/value names swapped, which was misleading; renamed
        to ``group_name``/``group_data`` (behavior unchanged).
        """
        if len(self.items_previous_selected) != 0:
            default_items = {}
            for item in self.items_previous_selected:
                for group_name, group_data in self.data_dict.items():
                    if item in group_data:
                        default_items[item] = group_data[item]
            self.data_dict['use_default_selection'] = default_items

    @observe('data_opt')
    def _update_file(self, change):
        """Observer for data_opt: select which data group is plotted.

        data_opt == 0 clears the selection; data_opt > 0 selects the
        (data_opt - 1)-th key of the sorted data_dict_keys.
        """
        try:
            if self.data_opt == 0:
                self.dict_to_plot = {}
                self.items_in_selected_group = []
                self.set_stat_for_all(bool_val=False)
                self.img_title = ''
            elif self.data_opt > 0:
                # self.set_stat_for_all(bool_val=False)
                plot_item = sorted(self.data_dict_keys)[self.data_opt - 1]
                self.img_title = str(plot_item)
                self.dict_to_plot = self.data_dict[plot_item]
                self.set_stat_for_all(bool_val=False)

                self.update_img_wizard_items()
                self.get_default_items(
                )  # get default elements every time when fitting is done

        except IndexError:
            # data_opt out of range for the current key list; ignore.
            pass

    @observe('scaler_name_index')
    def _get_scaler_data(self, change):
        """Observer for scaler_name_index: pick the normalization scaler.

        Index 0 means "no scaler"; index i > 0 selects scaler_items[i-1].
        Recomputes the low/high ranges and redraws the image.
        """
        if change['type'] == 'create':
            # Skip the initial notification fired at member creation.
            return

        if self.scaler_name_index == 0:
            self.scaler_data = None
        else:
            try:
                scaler_name = self.scaler_items[self.scaler_name_index - 1]
            except IndexError:
                scaler_name = None
            if scaler_name:
                self.scaler_data = self.scaler_norm_dict[scaler_name]
                logger.info('Use scaler data to normalize, '
                            'and the shape of scaler data is {}, '
                            'with (low, high) as ({}, {})'.format(
                                self.scaler_data.shape,
                                np.min(self.scaler_data),
                                np.max(self.scaler_data)))
        self.set_low_high_value(
        )  # reset low high values based on normalization
        self.show_image()
        self.update_img_wizard_items()

    def update_img_wizard_items(self):
        """This is for GUI purpose only.
        Table items will not be updated if list items keep the same.
        """
        # Clear then reassign a fresh list so the table always refreshes.
        self.items_in_selected_group = []
        self.items_in_selected_group = list(self.dict_to_plot.keys())

    def format_img_wizard_limit(self, value):
        """
        Format a range value for display in the 'Image Wizard'.

        Values whose magnitude lies outside [1e-3, 1e3] are rendered in
        scientific notation; everything else (including zero) is shown
        in fixed-point notation, both with 6 digits of precision.

        ..note::

        The function is called externally from 'enaml' code.

        Parameters:
        ===========
        value : float
            The value to be formatted

        Returns:
        ========
        str - the string representation of the floating point variable
        """
        magnitude = math.log10(abs(value)) if value != 0 else 0
        use_scientific = magnitude > 3 or magnitude < -3
        return f"{value:.6e}" if use_scientific else f"{value:.6f}"

    @observe('scale_opt', 'color_opt')
    def _update_scale(self, change):
        """Redraw the image when the scale or colormap option changes."""
        if change['type'] == 'create':
            # Ignore the notification fired at member creation.
            return
        self.show_image()

    @observe('pixel_or_pos')
    def _update_pp(self, change):
        """Redraw the image when pixel/position plotting mode changes."""
        self.show_image()

    @observe('grid_interpolate')
    def _update_gi(self, change):
        """Redraw the image when grid interpolation is toggled."""
        self.show_image()

    def plot_select_all(self):
        """Mark every image in the selected group for display."""
        self.set_stat_for_all(bool_val=True)

    def plot_deselect_all(self):
        """Unmark every image in the selected group."""
        self.set_stat_for_all(bool_val=False)

    @observe('scatter_show')
    def _change_image_plot_method(self, change):
        """Redraw when switching between image and scatter rendering."""
        if change['type'] == 'create':
            # Ignore the notification fired at member creation.
            return
        self.show_image()

    def set_stat_for_all(self, bool_val=False):
        """
        Set plotting status for all the 2D images, including low and high values.
        """
        keys = list(self.dict_to_plot.keys())

        # Clear first, then rebind a new dict (as the original did).
        self.stat_dict.clear()
        self.stat_dict = dict.fromkeys(keys, bool_val)

        self.limit_dict.clear()
        new_limits = {}
        for key in keys:
            # A distinct dict per key so limits can be edited independently.
            new_limits[key] = {'low': 0.0, 'high': 100.0}
        self.limit_dict = new_limits

        self.set_low_high_value()

    def set_low_high_value(self):
        """Set default low and high values based on normalization for each image.
        """
        # do not apply scaler norm on not scalable data
        self.range_dict.clear()
        for data_name, data in self.dict_to_plot.items():
            normalized = normalize_data_by_scaler(
                data,
                self.scaler_data,
                data_name=data_name,
                name_not_scalable=self.name_not_scalable)
            low, high = np.min(normalized), np.max(normalized)
            self.range_dict[data_name] = {
                'low': low,
                'low_default': low,
                'high': high,
                'high_default': high
            }

    def reset_low_high(self, name):
        """Reset low and high value to default based on normalization.
        """
        rng = self.range_dict[name]
        rng['low'], rng['high'] = rng['low_default'], rng['high_default']
        lim = self.limit_dict[name]
        lim['low'], lim['high'] = 0.0, 100.0
        self.update_img_wizard_items()
        self.show_image()

    def show_image(self):
        self.fig.clf()
        stat_temp = self.get_activated_num()
        stat_temp = OrderedDict(
            sorted(six.iteritems(stat_temp), key=lambda x: x[0]))

        # Check if positions data is available. Positions data may be unavailable
        # (not recorded in HDF5 file) if experiment is has not been completed.
        # While the data from the completed part of experiment may still be used,
        # plotting vs. x-y or scatter plot may not be displayed.
        positions_data_available = False
        if 'positions' in self.data_dict.keys():
            positions_data_available = True

        # Create local copies of self.pixel_or_pos, self.scatter_show and self.grid_interpolate
        pixel_or_pos_local = self.pixel_or_pos
        scatter_show_local = self.scatter_show
        grid_interpolate_local = self.grid_interpolate

        # Disable plotting vs x-y coordinates if 'positions' data is not available
        if not positions_data_available:
            if pixel_or_pos_local:
                pixel_or_pos_local = 0  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Plotting vs. x-y coordinates is disabled"
                )
            if scatter_show_local:
                scatter_show_local = False  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Scatter plot is disabled."
                )
            if grid_interpolate_local:
                grid_interpolate_local = False  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Interpolation is disabled."
                )

        low_lim = 1e-4  # define the low limit for log image
        plot_interp = 'Nearest'

        if self.scaler_data is not None:
            if np.count_nonzero(self.scaler_data) == 0:
                logger.warning('scaler is zero - scaling was not applied')
            elif len(self.scaler_data[self.scaler_data == 0]) > 0:
                logger.warning('scaler data has zero values')

        grey_use = self.color_opt

        ncol = int(np.ceil(np.sqrt(len(stat_temp))))
        try:
            nrow = int(np.ceil(len(stat_temp) / float(ncol)))
        except ZeroDivisionError:
            ncol = 1
            nrow = 1

        a_pad_v = 0.8
        a_pad_h = 0.5

        grid = ImageGrid(self.fig,
                         111,
                         nrows_ncols=(nrow, ncol),
                         axes_pad=(a_pad_v, a_pad_h),
                         cbar_location='right',
                         cbar_mode='each',
                         cbar_size='7%',
                         cbar_pad='2%',
                         share_all=True)

        def _compute_equal_axes_ranges(x_min, x_max, y_min, y_max):
            """
            Compute ranges for x- and y- axes of the plot. Make sure that the ranges for x- and y-axes are
            always equal and fit the maximum of the ranges for x and y values:
                  max(abs(x_max-x_min), abs(y_max-y_min))
            The ranges are set so that the data is always centered in the middle of the ranges

            Parameters
            ----------

            x_min, x_max, y_min, y_max : float
                lower and upper boundaries of the x and y values

            Returns
            -------

            x_axis_min, x_axis_max, y_axis_min, y_axis_max : float
                lower and upper boundaries of the x- and y-axes ranges
            """

            x_axis_min, x_axis_max, y_axis_min, y_axis_max = x_min, x_max, y_min, y_max
            x_range, y_range = abs(x_max - x_min), abs(y_max - y_min)
            if x_range > y_range:
                y_center = (y_max + y_min) / 2
                y_axis_max = y_center + x_range / 2
                y_axis_min = y_center - x_range / 2
            else:
                x_center = (x_max + x_min) / 2
                x_axis_max = x_center + y_range / 2
                x_axis_min = x_center - y_range / 2

            return x_axis_min, x_axis_max, y_axis_min, y_axis_max

        def _adjust_data_range_using_min_ratio(c_min,
                                               c_max,
                                               c_axis_range,
                                               *,
                                               min_ratio=0.01):
            """
            Adjust the range for plotted data along one axis (x or y). The adjusted range is
            applied to the 'extend' attribute of imshow(). The adjusted range is always greater
            than 'axis_range * min_ratio'. Such transformation has no physical meaning
            and performed for aesthetic reasons: stretching the image presentation of
            a scan with only a few lines (1-3) greatly improves visibility of data.

            Parameters
            ----------

            c_min, c_max : float
                boundaries of the data range (along x or y axis)
            c_axis_range : float
                range presented along the same axis

            Returns
            -------

            cmin, c_max : float
                adjusted boundaries of the data range
            """
            c_range = c_max - c_min
            if c_range < c_axis_range * min_ratio:
                c_center = (c_max + c_min) / 2
                c_new_range = c_axis_range * min_ratio
                c_min = c_center - c_new_range / 2
                c_max = c_center + c_new_range / 2
            return c_min, c_max

        # Plot one panel per selected dataset. k is the dataset name; v is the
        # selection flag (not used inside the loop body).
        for i, (k, v) in enumerate(six.iteritems(stat_temp)):

            # Normalize the dataset by the scaler; normalize_data_by_scaler is
            # expected to skip quantities listed in 'name_not_scalable'.
            data_dict = normalize_data_by_scaler(
                data_in=self.dict_to_plot[k],
                scaler=self.scaler_data,
                data_name=k,
                name_not_scalable=self.name_not_scalable)

            if pixel_or_pos_local or scatter_show_local:

                # Plot in spatial coordinates taken from the positional data.
                # xd_min, xd_max, yd_min, yd_max = min(self.x_pos), max(self.x_pos),
                #     min(self.y_pos), max(self.y_pos)
                x_pos_2D = self.data_dict['positions']['x_pos']
                y_pos_2D = self.data_dict['positions']['y_pos']
                xd_min, xd_max, yd_min, yd_max = x_pos_2D.min(), x_pos_2D.max(
                ), y_pos_2D.min(), y_pos_2D.max()
                xd_axis_min, xd_axis_max, yd_axis_min, yd_axis_max = \
                    _compute_equal_axes_ranges(xd_min, xd_max, yd_min, yd_max)

                # Stretch degenerate (very thin) data ranges for visibility.
                xd_min, xd_max = _adjust_data_range_using_min_ratio(
                    xd_min, xd_max, xd_axis_max - xd_axis_min)
                yd_min, yd_max = _adjust_data_range_using_min_ratio(
                    yd_min, yd_max, yd_axis_max - yd_axis_min)

                # Adjust the direction of each axis depending on the direction in which encoder values changed
                #   during the experiment. Data is plotted starting from the upper-right corner of the plot
                if x_pos_2D[0, 0] > x_pos_2D[0, -1]:
                    xd_min, xd_max, xd_axis_min, xd_axis_max = xd_max, xd_min, xd_axis_max, xd_axis_min
                if y_pos_2D[0, 0] > y_pos_2D[-1, 0]:
                    yd_min, yd_max, yd_axis_min, yd_axis_max = yd_max, yd_min, yd_axis_max, yd_axis_min

            else:

                # Plot in pixel coordinates. Extremely thin scans (aspect ratio
                # worse than ~1:100) are artificially thickened for visibility.
                yd, xd = data_dict.shape

                xd_min, xd_max, yd_min, yd_max = 0, xd, 0, yd
                if (yd <= math.floor(xd / 100)) and (xd >= 200):
                    yd_min, yd_max = -math.floor(xd / 200), math.ceil(xd / 200)
                if (xd <= math.floor(yd / 100)) and (yd >= 200):
                    xd_min, xd_max = -math.floor(yd / 200), math.ceil(yd / 200)

                xd_axis_min, xd_axis_max, yd_axis_min, yd_axis_max = \
                    _compute_equal_axes_ranges(xd_min, xd_max, yd_min, yd_max)

            if self.scale_opt == 'Linear':

                # Linear color scale. Color limits are taken from the
                # user-controlled percentage sliders in self.limit_dict.
                low_ratio = self.limit_dict[k]['low'] / 100.0
                high_ratio = self.limit_dict[k]['high'] / 100.0
                if self.scaler_data is None:
                    minv = self.range_dict[k]['low']
                    maxv = self.range_dict[k]['high']
                else:
                    # Unfortunately, the new normalization procedure requires to recalculate min and max values
                    minv = np.min(data_dict)
                    maxv = np.max(data_dict)
                low_limit = (maxv - minv) * low_ratio + minv
                high_limit = (maxv - minv) * high_ratio + minv

                # Set some minimum range for the colorbar (otherwise it will have white fill)
                if math.isclose(low_limit, high_limit, abs_tol=2e-20):
                    if abs(low_limit) < 1e-20:  # The value is zero
                        dv = 1e-20
                    else:
                        dv = math.fabs(low_limit * 0.01)
                    high_limit += dv
                    low_limit -= dv

                if not scatter_show_local:
                    if grid_interpolate_local:
                        data_dict, _, _ = grid_interpolate(
                            data_dict, self.data_dict['positions']['x_pos'],
                            self.data_dict['positions']['y_pos'])
                    im = grid[i].imshow(data_dict,
                                        cmap=grey_use,
                                        interpolation=plot_interp,
                                        extent=(xd_min, xd_max, yd_max,
                                                yd_min),
                                        origin='upper',
                                        clim=(low_limit, high_limit))
                    grid[i].set_ylim(yd_axis_max, yd_axis_min)
                else:
                    xx = self.data_dict['positions']['x_pos']
                    yy = self.data_dict['positions']['y_pos']

                    # The following condition prevents crash if different file is loaded while
                    #    the scatter plot is open (PyXRF specific issue)
                    if data_dict.shape == xx.shape and data_dict.shape == yy.shape:
                        im = grid[i].scatter(
                            xx,
                            yy,
                            c=data_dict,
                            marker='s',
                            s=500,
                            alpha=1.0,  # Originally: alpha=0.8
                            cmap=grey_use,
                            vmin=low_limit,
                            vmax=high_limit,
                            linewidths=1,
                            linewidth=0)
                        grid[i].set_ylim(yd_axis_max, yd_axis_min)

                grid[i].set_xlim(xd_axis_min, xd_axis_max)

                # Panel title placed just above the axes.
                grid_title = k
                grid[i].text(0,
                             1.01,
                             grid_title,
                             ha='left',
                             va='bottom',
                             transform=grid[i].axes.transAxes)

                grid.cbar_axes[i].colorbar(im)
                im.colorbar.formatter = im.colorbar.cbar_axis.get_major_formatter(
                )
                # im.colorbar.ax.get_xaxis().set_ticks([])
                # im.colorbar.ax.get_xaxis().set_ticks([], minor=True)
                grid.cbar_axes[i].ticklabel_format(style='sci',
                                                   scilimits=(-3, 4),
                                                   axis='both')

                #  Do not remove this code, may be useful in the future (Dmitri G.) !!!
                #  Print label for colorbar
                # cax = grid.cbar_axes[i]
                # axis = cax.axis[cax.orientation]
                # axis.label.set_text("$[a.u.]$")

            else:

                # Logarithmic color scale. 'low_lim' is presumably a fixed
                # fraction defined in the enclosing scope -- TODO confirm.
                maxz = np.max(data_dict)
                # Set some reasonable minimum range for the colorbar
                #   Zeros or negative numbers will be shown in white
                if maxz <= 1e-30:
                    maxz = 1

                if not scatter_show_local:
                    if grid_interpolate_local:
                        data_dict, _, _ = grid_interpolate(
                            data_dict, self.data_dict['positions']['x_pos'],
                            self.data_dict['positions']['y_pos'])
                    im = grid[i].imshow(data_dict,
                                        norm=LogNorm(vmin=low_lim * maxz,
                                                     vmax=maxz,
                                                     clip=True),
                                        cmap=grey_use,
                                        interpolation=plot_interp,
                                        extent=(xd_min, xd_max, yd_max,
                                                yd_min),
                                        origin='upper',
                                        clim=(low_lim * maxz, maxz))
                    grid[i].set_ylim(yd_axis_max, yd_axis_min)
                else:
                    im = grid[i].scatter(
                        self.data_dict['positions']['x_pos'],
                        self.data_dict['positions']['y_pos'],
                        norm=LogNorm(vmin=low_lim * maxz, vmax=maxz,
                                     clip=True),
                        c=data_dict,
                        marker='s',
                        s=500,
                        alpha=1.0,  # Originally: alpha=0.8
                        cmap=grey_use,
                        linewidths=1,
                        linewidth=0)
                    grid[i].set_ylim(yd_axis_min, yd_axis_max)

                grid[i].set_xlim(xd_axis_min, xd_axis_max)

                grid_title = k
                grid[i].text(0,
                             1.01,
                             grid_title,
                             ha='left',
                             va='bottom',
                             transform=grid[i].axes.transAxes)

                grid.cbar_axes[i].colorbar(im)
                im.colorbar.formatter = im.colorbar.cbar_axis.get_major_formatter(
                )
                im.colorbar.ax.get_xaxis().set_ticks([])
                im.colorbar.ax.get_xaxis().set_ticks([], minor=True)
                im.colorbar.cbar_axis.set_minor_formatter(
                    mticker.LogFormatter())

            # Common axis cosmetics for both linear and log branches.
            grid[i].get_xaxis().set_major_locator(
                mticker.MaxNLocator(nbins="auto"))
            grid[i].get_yaxis().set_major_locator(
                mticker.MaxNLocator(nbins="auto"))

            grid[i].get_xaxis().get_major_formatter().set_useOffset(False)
            grid[i].get_yaxis().get_major_formatter().set_useOffset(False)

        self.fig.suptitle(self.img_title, fontsize=20)
        self.fig.canvas.draw_idle()

    def get_activated_num(self):
        """Collect the selected items for plotting.

        Returns
        -------
        dict
            Subset of ``self.stat_dict`` containing only the entries whose
            value is the boolean ``True``.
        """
        # ``is True`` (not truthiness) is deliberate: only entries explicitly
        # set to the boolean True count as selected. Native ``.items()`` is
        # used instead of ``six.iteritems`` (Python 3 only, consistent with
        # ``record_selected`` below).
        return {k: v for k, v in self.stat_dict.items() if v is True}

    def record_selected(self):
        """Save the list of items in cache for later use.

        Records the names of currently selected items (value ``True`` in
        ``self.stat_dict``) and stores the corresponding datasets under the
        'use_default_selection' key of ``self.data_dict``.
        """
        selected = [name for name, flag in self.stat_dict.items() if flag is True]
        self.items_previous_selected = selected
        logger.info('Items are set as default: {}'.format(
            selected))
        self.data_dict['use_default_selection'] = {
            name: self.dict_to_plot[name] for name in selected
        }
        self.data_dict_keys = list(self.data_dict.keys())
コード例 #8
0
ファイル: nidaq.py プロジェクト: buranconsult/psiexperiment
class NIDAQEngine(Engine):
    '''
    Hardware interface

    The tasks are started in the order they are configured. Most NI devices can
    only support a single hardware-timed task of a specified type (e.g., analog
    input, analog output, digital input, digital output are all unique task
    types).
    '''
    #: Name of the engine. This is used for logging and configuration purposes
    #: (we can have multiple NIDAQ engines if we need to define separate sets
    #: of tasks (e.g., if we have more than one multifunction DAQ card).
    engine_name = 'nidaq'

    #: Flag indicating whether engine was configured
    _configured = Bool(False)

    #: Poll period (in seconds). This defines how often callbacks for the
    #: analog outputs are notified (i.e., to generate additional samples for
    #: playout).  If the poll period is too long, then the analog output may
    #: run out of samples.
    hw_ao_monitor_period = d_(Float(1)).tag(metadata=True)

    #: Size of buffer (in seconds). This defines how much data is pregenerated
    #: for the buffer before starting acquisition. This is important because
    hw_ao_buffer_size = d_(Float(10)).tag(metadata=True)

    #: Even though data is written to the analog outputs, it is buffered in
    #: computer memory until it's time to be transferred to the onboard buffer
    #: of the NI acquisition card. NI-DAQmx handles this behind the scenes
    #: (i.e., when the acquisition card needs additional samples, NI-DAQmx will
    #: transfer the next chunk of data from the computer memory). We can
    #: overwrite data that's been buffered in computer memory (e.g., so we can
    #: insert a target in response to a nose-poke). However, we cannot
    #: overwrite data that's already been transfered to the onboard buffer. So,
    #: the onboard buffer size determines how quickly we can change the analog
    #: output in response to an event.
    hw_ao_onboard_buffer = d_(Int(4095)).tag(metadata=True)
    # TODO: This is not configurable on every card. How do we know if it's
    # configurable?

    #: Since any function call takes a small fraction of time (e.g., nanoseconds
    #: to milliseconds), we can't simply overwrite data starting at
    #: hw_ao_onboard_buffer+1. By the time the function calls are complete, the
    #: DAQ probably has already transferred a couple hundred samples to the
    #: buffer. This parameter will likely need some tweaking (i.e., only you can
    #: determine an appropriate value for this based on the needs of your
    #: program).
    hw_ao_min_writeahead = d_(Int(8191 + 1000)).tag(metadata=True)

    #: Total samples written to the buffer.
    total_samples_written = Int(0)

    #: Mapping of task name -> DAQmx task handle (and related bookkeeping).
    _tasks = Typed(dict)
    _task_done = Typed(dict)
    _callbacks = Typed(dict)
    _timers = Typed(dict)
    _uint32 = Typed(ctypes.c_uint32)
    _uint64 = Typed(ctypes.c_uint64)
    _int32 = Typed(ctypes.c_int32)

    #: Sampling rates (Hz) discovered when the AO/AI tasks are configured.
    ao_fs = Typed(float).tag(metadata=True)
    ai_fs = Typed(float).tag(metadata=True)

    terminal_mode_map = {
        'differential': mx.DAQmx_Val_Diff,
        'pseudodifferential': mx.DAQmx_Val_PseudoDiff,
        'RSE': mx.DAQmx_Val_RSE,
        'NRSE': mx.DAQmx_Val_NRSE,
        'default': mx.DAQmx_Val_Cfg_Default,
    }

    terminal_coupling_map = {
        None: None,
        'AC': mx.DAQmx_Val_AC,
        'DC': mx.DAQmx_Val_DC,
        'ground': mx.DAQmx_Val_GND,
    }

    # This defines the function for the clock that synchronizes the tasks.
    sample_time = Callable()

    #: Class-level registry shared by all engine instances (the default list
    #: is deliberately shared; each __init__ appends ``self`` to it).
    instances = Value([])

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self.instances.append(self)

        self._tasks = {}
        self._callbacks = {}
        self._timers = {}
        self._configured = False

    def configure(self, active=True):
        """Create and configure the NI-DAQmx tasks for all active channels."""
        log.debug('Configuring {} engine'.format(self.name))

        counter_channels = self.get_channels('counter', active=active)
        sw_do_channels = self.get_channels('digital', 'output', 'software',
                                           active=active)
        hw_ai_channels = self.get_channels('analog', 'input', 'hardware',
                                           active=active)
        hw_di_channels = self.get_channels('digital', 'input', 'hardware',
                                           active=active)
        hw_ao_channels = self.get_channels('analog', 'output', 'hardware',
                                           active=active)

        if counter_channels:
            log.debug('Configuring counter channels')
            self.configure_counters(counter_channels)

        if sw_do_channels:
            log.debug('Configuring SW DO channels')
            self.configure_sw_do(sw_do_channels)

        if hw_ai_channels:
            log.debug('Configuring HW AI channels')
            self.configure_hw_ai(hw_ai_channels)

        if hw_di_channels:
            raise NotImplementedError
            #log.debug('Configuring HW DI channels')
            #lines = ','.join(get_channel_property(hw_di_channels, 'channel', True))
            #names = get_channel_property(hw_di_channels, 'name', True)
            #fs = get_channel_property(hw_di_channels, 'fs')
            #start_trigger = get_channel_property(hw_di_channels, 'start_trigger')
            ## Required for M-series to enable hardware-timed digital
            ## acquisition. TODO: Make this a setting that can be configured
            ## since X-series doesn't need this hack.
            #device = hw_di_channels[0].channel.strip('/').split('/')[0]
            #clock = '/{}/Ctr0'.format(device)
            #self.configure_hw_di(fs, lines, names, start_trigger, clock)

        # Configure the analog output last because acquisition is synced with
        # the analog output signal (i.e., when the analog output starts, the
        # analog input begins acquiring such that sample 0 of the input
        # corresponds with sample 0 of the output).

        # TODO: eventually we should be able to inspect the  'start_trigger'
        # property on the channel configuration to decide the order in which the
        # tasks are started.
        if hw_ao_channels:
            log.debug('Configuring HW AO channels')
            self.configure_hw_ao(hw_ao_channels)

        # Choose sample clock based on what channels have been configured.
        if hw_ao_channels:
            self.sample_time = self.ao_sample_time
        elif hw_ai_channels:
            self.sample_time = self.ai_sample_time

        # Configure task done events so that we can fire a callback if
        # acquisition is done.
        self._task_done = {}
        for name, task in self._tasks.items():
            # Bind ``name`` at definition time via the default argument.
            # The original ``nonlocal name`` version closed over the loop
            # variable, so every callback reported the *last* task's name.
            def cb(task, s, cb_data, name=name):
                self.task_complete(name)
                return 0
            cb_ptr = mx.DAQmxDoneEventCallbackPtr(cb)
            mx.DAQmxRegisterDoneEvent(task, 0, cb_ptr, None)
            # Keep a reference on the task so the ctypes callback pointer is
            # not garbage-collected while DAQmx still holds it.
            task._done_cb_ptr_engine = cb_ptr
            self._task_done[name] = False

        super().configure()

        # Required by start. This allows us to do the configuration
        # on the fly when starting the engines if the configure method hasn't
        # been called yet.
        self._configured = True
        log.debug('Completed engine configuration')

    def task_complete(self, task_name):
        """Handle a DAQmx done event: flush remaining samples and, once all
        hardware-timed tasks are done, invoke the registered 'done' callbacks.
        """
        log.debug('Task %s complete', task_name)
        self._task_done[task_name] = True
        task = self._tasks[task_name]

        # We have frozen the initial arguments (in the case of hw_ai_helper,
        # that would be cb, channels, discard; in the case of hw_ao_helper,
        # that would be cb) using functools.partial and need to provide task,
        # cb_samples and cb_data. For hw_ai_helper, setting cb_samples to 1
        # means that we read all remaning samples, regardless of whether they
        # fit evenly into a block of samples. The other two arguments
        # (event_type and cb_data) are required of the function signature by
        # NIDAQmx but are unused.
        try:
            task._cb(task, None, 1, None)
        except Exception as e:
            log.exception(e)

        # Only check to see if hardware-timed tasks are complete.
        # Software-timed tasks must be explicitly canceled by the user.
        done = [v for t, v in self._task_done.items() if t.startswith('hw')]
        if all(done):
            for cb in self._callbacks.get('done', []):
                cb()

    def configure_counters(self, channels):
        """Create the counter task for the given channels."""
        task = setup_counters(channels)
        self._tasks['counter'] = task

    def configure_hw_ao(self, channels):
        '''
        Initialize hardware-timed analog output

        Parameters
        ----------
        fs : float
            Sampling frequency of output (e.g., 100e3).
        lines : str
            Analog output lines to use (e.g., 'Dev1/ao0:4' to specify a range of
            lines or 'Dev1/ao0,Dev1/ao4' to specify specific lines).
        expected_range : (float, float)
            Tuple of upper/lower end of expected range. The maximum range
            allowed by most NI devices is (-10, 10). Some devices (especially
            newer ones) will optimize the output resolution based on the
            expected range of the signal.
        '''
        task = setup_hw_ao(channels, self.hw_ao_buffer_size,
                           self.hw_ao_monitor_period, self.hw_ao_callback,
                           '{}_hw_ao'.format(self.name))
        self._tasks['hw_ao'] = task
        # Propagate the actual sampling rate chosen by the task back to the
        # engine and to each channel.
        self.ao_fs = task._fs
        for channel in channels:
            channel.fs = task._fs
        self.total_samples_written = 0

    def configure_hw_ai(self, channels):
        """Initialize hardware-timed analog input."""
        task_name = '{}_hw_ai'.format(self.name)
        task = setup_hw_ai(channels, self.hw_ai_monitor_period,
                           self._hw_ai_callback, task_name)
        self._tasks['hw_ai'] = task
        self.ai_fs = task._fs

    def configure_sw_ao(self, lines, expected_range, names=None,
                        initial_state=None):
        # NOTE: intentionally disabled; the code below is kept for reference.
        raise NotImplementedError
        if initial_state is None:
            initial_state = np.zeros(len(names), dtype=np.double)
        task_name = '{}_sw_ao'.format(self.name)
        task = setup_sw_ao(lines, expected_range, task_name)
        task._names = verify_channel_names(task, names)
        task._devices = device_list(task)
        self._tasks['sw_ao'] = task
        self.write_sw_ao(initial_state)

    def configure_hw_di(self, fs, lines, names=None, trigger=None, clock=None):
        # NOTE: intentionally disabled; the code below is kept for reference.
        raise NotImplementedError
        callback_samples = int(self.hw_ai_monitor_period*fs)
        task_name = '{}_hw_di'.format(self.name)
        task, clock_task = setup_hw_di(fs, lines, self._hw_di_callback,
                                       callback_samples, trigger, clock,
                                       task_name)
        task._names = verify_channel_names(task, names)
        task._devices = device_list(task)
        task._fs = fs
        if clock_task is not None:
            self._tasks['hw_di_clock'] = clock_task
        self._tasks['hw_di'] = task

    def configure_hw_do(self, fs, lines, names):
        raise NotImplementedError

    def configure_sw_do(self, channels):
        """Initialize software-timed digital output, starting all lines low."""
        task_name = '{}_sw_do'.format(self.name)
        task = setup_sw_do(channels, task_name)
        self._tasks['sw_do'] = task
        initial_state = np.zeros(len(channels), dtype=np.uint8)
        self.write_sw_do(initial_state)

    def configure_et(self, lines, clock, names=None):
        '''
        Setup change detection with high-precision timestamps

        Anytime a rising or falling edge is detected on one of the specified
        lines, a timestamp based on the specified clock will be captured. For
        example, if the clock is 'ao/SampleClock', then the timestamp will be
        the number of samples played at the point when the line changed state.

        Parameters
        ----------
        lines : string
            Digital lines (in NI-DAQmx syntax, e.g., 'Dev1/port0/line0:4') to
            monitor.
        clock : string
            Reference clock from which timestamps will be drawn.
        names : string (optional)
            Aliases for the lines. When aliases are provided, registered
            callbacks will receive the alias for the line instead of the
            NI-DAQmx notation.

        Notes
        -----
        Be aware of the limitations of your device. All X-series devices support
        change detection on all ports; however, only some M-series devices do
        (and then, only on port 0).
        '''
        # Find out which device the lines are from. Use this to configure the
        # event timer. Right now we don't want to deal with multi-device event
        # timers. If there's more than one device, then we should configure each
        # separately.
        raise NotImplementedError

        # TODO: How to determine sampling rate of task?
        names = channel_names('digital', lines, names)
        devices = device_list(lines, 'digital')
        if len(devices) != 1:
            raise ValueError('Cannot configure multi-device event timer')

        trigger = '/{}/ChangeDetectionEvent'.format(devices[0])
        counter = '/{}/Ctr0'.format(devices[0])
        task_name = '{}_et'.format(self.name)
        et_task = setup_event_timer(trigger, counter, clock, task_name)
        task_name = '{}_cd'.format(self.name)
        cd_task = setup_change_detect_callback(lines, self._et_fired, et_task,
                                               names, task_name)
        cd_task._names = names
        self._tasks['et_task'] = et_task
        self._tasks['cd_task'] = cd_task

    def _get_channel_slice(self, task_name, channel_names):
        """Return the index of the named channel in the task, or Ellipsis for
        'all channels' when no name is given."""
        if channel_names is None:
            return Ellipsis
        else:
            return self._tasks[task_name]._names.index(channel_names)

    def register_done_callback(self, callback):
        """Register a callback fired when all hardware-timed tasks complete."""
        self._callbacks.setdefault('done', []).append(callback)

    def register_ao_callback(self, callback, channel_name=None):
        s = self._get_channel_slice('hw_ao', channel_name)
        self._callbacks.setdefault('ao', []).append((channel_name, s, callback))

    def register_ai_callback(self, callback, channel_name=None):
        s = self._get_channel_slice('hw_ai', channel_name)
        self._callbacks.setdefault('ai', []).append((channel_name, s, callback))

    def register_di_callback(self, callback, channel_name=None):
        s = self._get_channel_slice('hw_di', channel_name)
        self._callbacks.setdefault('di', []).append((channel_name, s, callback))

    def register_et_callback(self, callback, channel_name=None):
        s = self._get_channel_slice('cd_task', channel_name)
        self._callbacks.setdefault('et', []).append((channel_name, s, callback))

    def unregister_done_callback(self, callback):
        try:
            self._callbacks['done'].remove(callback)
        except (KeyError, ValueError):
            # KeyError: no 'done' callbacks were ever registered.
            # ValueError: this callback was not in the list (list.remove).
            log.warning('Callback no longer exists.')

    def unregister_ao_callback(self, callback, channel_name):
        try:
            s = self._get_channel_slice('hw_ao', channel_name)
            self._callbacks['ao'].remove((channel_name, s, callback))
        except (KeyError, AttributeError, ValueError):
            # ValueError covers both a missing channel name (list.index in
            # _get_channel_slice) and a missing tuple (list.remove).
            log.warning('Callback no longer exists.')

    def unregister_ai_callback(self, callback, channel_name):
        try:
            s = self._get_channel_slice('hw_ai', channel_name)
            self._callbacks['ai'].remove((channel_name, s, callback))
        except (KeyError, AttributeError, ValueError):
            log.warning('Callback no longer exists.')

    def unregister_di_callback(self, callback, channel_name):
        s = self._get_channel_slice('hw_di', channel_name)
        self._callbacks['di'].remove((channel_name, s, callback))

    def unregister_et_callback(self, callback, channel_name):
        s = self._get_channel_slice('cd_task', channel_name)
        self._callbacks['et'].remove((channel_name, s, callback))

    def write_sw_ao(self, state):
        """Write one sample per channel to the software-timed analog output."""
        task = self._tasks['sw_ao']
        state = np.array(state).astype(np.double)
        result = ctypes.c_int32()
        mx.DAQmxWriteAnalogF64(task, 1, True, 0, mx.DAQmx_Val_GroupByChannel,
                               state, result, None)
        if result.value != 1:
            raise ValueError('Unable to update software-timed AO')
        task._current_state = state

    def write_sw_do(self, state):
        """Write one sample per line to the software-timed digital output."""
        task = self._tasks['sw_do']
        state = np.asarray(state).astype(np.uint8)
        result = ctypes.c_int32()
        mx.DAQmxWriteDigitalLines(task, 1, True, 0, mx.DAQmx_Val_GroupByChannel,
                                  state, result, None)
        if result.value != 1:
            raise ValueError('Problem writing data to software-timed DO')
        task._current_state = state

    def set_sw_do(self, name, state):
        """Set a single named digital line, preserving the others."""
        task = self._tasks['sw_do']
        i = task._names.index(name)
        new_state = task._current_state.copy()
        new_state[i] = state
        self.write_sw_do(new_state)

    def set_sw_ao(self, name, state):
        """Set a single named analog output, preserving the others."""
        task = self._tasks['sw_ao']
        i = task._names.index(name)
        new_state = task._current_state.copy()
        new_state[i] = state
        self.write_sw_ao(new_state)

    def fire_sw_do(self, name, duration=0.1):
        """Pulse the named digital line high for `duration` seconds."""
        # TODO - Store reference to timer so that we can eventually track the
        # state of different timers and cancel pending timers when necessary.
        self.set_sw_do(name, 1)
        timer = Timer(duration, lambda: self.set_sw_do(name, 0))
        timer.start()

    def _et_fired(self, line_index, change, event_time):
        for i, cb in self._callbacks.get('et', []):
            if i == line_index:
                cb(change, event_time)

    @halt_on_error
    def _hw_ai_callback(self, samples):
        # Undo the scaling factor applied by the acquisition task before
        # dispatching to registered AI callbacks.
        samples /= self._tasks['hw_ai']._sf
        for channel_name, s, cb in self._callbacks.get('ai', []):
            cb(samples[s])

    def _hw_di_callback(self, samples):
        for i, cb in self._callbacks.get('di', []):
            cb(samples[i])

    def _get_hw_ao_samples(self, offset, samples):
        """Collect the next `samples` output samples from every AO channel."""
        channels = self.get_channels('analog', 'output', 'hardware')
        data = np.empty((len(channels), samples), dtype=np.double)
        for channel, ch_data in zip(channels, data):
            channel.get_samples(offset, samples, out=ch_data)
        return data

    def get_space_available(self, name=None):
        """Return the number of samples that can be written to the AO buffer."""
        # It doesn't matter what the output channel is. Space will be the same
        # for all.
        result = ctypes.c_uint32()
        mx.DAQmxGetWriteSpaceAvail(self._tasks['hw_ao'], result)
        return result.value

    def hw_ao_callback(self, samples):
        """Generate and upload the next chunk of analog-output samples."""
        # Get the next set of samples to upload to the buffer
        try:
            with self.lock:
                log_ao.trace('#> HW AO callback. Acquired lock for engine %s', self.name)
                available_samples = self.get_space_available()
                if available_samples < samples:
                    log_ao.trace('Not enough samples available for writing')
                else:
                    data = self._get_hw_ao_samples(self.total_samples_written, samples)
                    self.write_hw_ao(data, self.total_samples_written, timeout=0)
        except Exception:
            # Queried for debugging (inspect in the debugger/post-mortem);
            # the exception itself is propagated unchanged.
            offset = ctypes.c_int32()
            curr_write = ctypes.c_uint64()
            mx.DAQmxGetWriteOffset(self._tasks['hw_ao'], offset)
            mx.DAQmxGetWriteCurrWritePos(self._tasks['hw_ao'], curr_write)
            raise

    def update_hw_ao(self, name, offset):
        """Overwrite already-buffered AO data starting at `offset`."""
        # Get the next set of samples to upload to the buffer. Ignore the
        # channel name because we need to update all channels simultaneously.
        if offset > self.total_samples_written:
            return

        available = self.get_space_available()
        samples = available - (offset - self.total_samples_written)
        if samples <= 0:
            return

        log_ao.info('Updating hw ao at %d with %d samples', offset, samples)
        data = self._get_hw_ao_samples(offset, samples)
        self.write_hw_ao(data, offset=offset, timeout=0)

    def update_hw_ao_multiple(self, offsets, names):
        # This is really simple to implement since we have to update all
        # channels at once. So, we just pick the minimum offset and let
        # `update_hw_ao` do the work.
        self.update_hw_ao(None, min(offsets))

    @halt_on_error
    def write_hw_ao(self, data, offset, timeout=1):
        """Write `data` to the hardware AO buffer starting at sample `offset`."""
        # TODO: add a safety-check to make sure waveform doesn't exceed limits.
        # This is a recoverable error unless the DAQmx API catches it instead.

        # Due to historical limitations in the DAQmx API, the write offset is a
        # signed 32-bit integer. For long-running applications, we will have an
        # overflow if we attempt to set the offset relative to the first sample
        # written. Therefore, we compute the write offset relative to the last
        # sample written (for requested offsets it should be negative).
        # Computed before the try block so the logging in the except clause
        # can never raise NameError on an unbound local.
        task = self._tasks['hw_ao']
        relative_offset = offset - self.total_samples_written
        try:
            result = ctypes.c_int32()
            mx.DAQmxSetWriteOffset(task, relative_offset)
            mx.DAQmxWriteAnalogF64(task, data.shape[-1], False, timeout, mx.DAQmx_Val_GroupByChannel, data.astype(np.float64), result, None)
            mx.DAQmxSetWriteOffset(task, 0)

            # Calculate total samples written
            self.total_samples_written = self.total_samples_written + relative_offset + data.shape[-1]

        except Exception:
            # If we log on every call, the logfile will get quite verbose.
            # Let's only log this information on an Exception.
            log_ao.info('Failed to write %r samples starting at %r', data.shape, offset)
            log_ao.info(' * Offset is %d samples relative to current write position %d', relative_offset, self.total_samples_written)
            log_ao.info(' * Current read position is %d and onboard buffer size is %d', self.ao_sample_clock(), task._onboard_buffer_size)
            raise

    def get_ts(self):
        """Return the current timestamp (seconds), or NaN on failure."""
        try:
            return self.sample_time()
        except Exception as e:
            log.exception(e)
            return np.nan

    def start(self):
        """Commit, prime (AO) and start all configured tasks."""
        if not self._configured:
            log.debug('Tasks were not configured yet')
            self.configure()

        log.debug('Reserving NIDAQmx task resources')
        for task in self._tasks.values():
            mx.DAQmxTaskControl(task, mx.DAQmx_Val_Task_Commit)

        if 'hw_ao' in self._tasks:
            log.debug('Calling HW ao callback before starting tasks')
            samples = self.get_space_available()
            self.hw_ao_callback(samples)

        log.debug('Starting NIDAQmx tasks')
        for task in self._tasks.values():
            log.debug('Starting task {}'.format(task._name))
            mx.DAQmxStartTask(task)

    def stop(self):
        """Clear all tasks and reset the engine to an unconfigured state."""
        # TODO: I would love to be able to stop a task and keep it in memory
        # without having to restart; however, this will require some thought as
        # to the optimal way to do this. For now, we just clear everything.
        # Configuration is generally fairly quick.
        if not self._configured:
            return
        log.debug('Stopping engine')
        for task in self._tasks.values():
            mx.DAQmxClearTask(task)
        self._callbacks = {}
        self._configured = False

    def ai_sample_clock(self):
        """Return total samples per channel acquired by the AI task."""
        task = self._tasks['hw_ai']
        result = ctypes.c_uint64()
        mx.DAQmxGetReadTotalSampPerChanAcquired(task, result)
        return result.value

    def ai_sample_time(self):
        return self.ai_sample_clock()/self.ai_fs

    def ao_sample_clock(self):
        """Return total samples per channel generated by the AO task (0 if
        the task is missing or the query fails)."""
        try:
            task = self._tasks['hw_ao']
            result = ctypes.c_uint64()
            mx.DAQmxGetWriteTotalSampPerChanGenerated(task, result)
            return result.value
        except Exception:
            return 0

    def ao_sample_time(self):
        return self.ao_sample_clock()/self.ao_fs

    def get_buffer_size(self, name):
        return self.hw_ao_buffer_size
コード例 #9
0
ファイル: nidaq.py プロジェクト: buranconsult/psiexperiment
class NIDAQCounterChannel(NIDAQGeneralMixin, CounterChannel):
    """Counter channel configuration for NIDAQmx hardware."""

    #: Sample count for the high phase of the counter output
    #: (assumed from the name -- confirm against the engine configuration).
    high_samples = d_(Int().tag(metadata=True))

    #: Sample count for the low phase of the counter output
    #: (assumed from the name -- confirm against the engine configuration).
    low_samples = d_(Int().tag(metadata=True))

    #: Name of the terminal used as the counter source.
    source_terminal = d_(Str().tag(metadata=True))
コード例 #10
0
class OccViewer(Control):
    """ A control that displays a 3D OpenCascade (OCC) viewport.

    NOTE(review): the original docstring and the `position` comment
    described a spin box; they appear to have been copied from another
    widget and are corrected here.
    """
    #: A reference to the ProxyOccViewer object.
    proxy = Typed(ProxyOccViewer)

    #: Viewer position. Defaults to (0, 0).
    position = d_(Tuple(Int(strict=False), default=(0, 0)))

    #: Display mode
    display_mode = d_(Enum('shaded', 'hlr', 'wireframe'))

    #: Selection mode
    selection_mode = d_(Enum('shape', 'neutral', 'face', 'edge', 'vertex'))

    #: Selected items (read-only; written by the toolkit proxy)
    selection = d_(List(), writable=False)

    #: View direction
    view_mode = d_(
        Enum('iso', 'top', 'bottom', 'left', 'right', 'front', 'rear'))

    #: Selection event
    #reset_view = d_(Event(),writable=False)

    #: Show trihedron (axis indicator), or disable it
    trihedron_mode = d_(Enum('right-lower', 'disabled'))

    #: Background gradient -- six ints, presumably two RGB triples
    #: (top, bottom); TODO confirm against the proxy implementation.
    background_gradient = d_(
        Tuple(Int(), default=(206, 215, 222, 128, 128, 128)))

    #: Display shadows
    shadows = d_(Bool(False))

    #: Display reflections
    reflections = d_(Bool(True))

    #: Enable antialiasing
    antialiasing = d_(Bool(True))

    #: View expands freely in width by default.
    hug_width = set_default('ignore')

    #: View expands freely in height by default.
    hug_height = set_default('ignore')

    #: Events
    #: Raise StopIteration to indicate handling should stop
    key_pressed = d_(Event(), writable=False)
    mouse_pressed = d_(Event(), writable=False)
    mouse_released = d_(Event(), writable=False)
    mouse_wheeled = d_(Event(), writable=False)
    mouse_moved = d_(Event(), writable=False)

    # -------------------------------------------------------------------------
    # Observers
    # -------------------------------------------------------------------------
    # NOTE(review): 'double_buffer' is observed but not declared on this
    # class -- presumably defined on a base class; verify, otherwise the
    # observer registration may fail.
    @observe('position', 'display_mode', 'view_mode', 'trihedron_mode',
             'selection_mode', 'background_gradient', 'double_buffer',
             'shadows', 'reflections', 'antialiasing')
    def _update_proxy(self, change):
        """ An observer which sends state change to the proxy.
        """
        # The superclass handler implementation is sufficient.
        super(OccViewer, self)._update_proxy(change)
コード例 #11
0
ファイル: plugin.py プロジェクト: ylwb/declaracad
class ViewerProcess(ProcessLineReceiver):
    """ Manages the external viewer subprocess.

    Talks to the viewer over stdin/stdout with json-rpc style messages,
    keeps it alive with periodic pings, and restarts it (up to
    `max_retries` times) when it crashes.

    """
    #: Window id obtained after starting the process
    window_id = Int()

    #: Process handle
    process = Instance(object)

    #: Reference to the plugin
    plugin = ForwardInstance(lambda: ViewerPlugin)

    #: Document
    document = ForwardInstance(document_type)

    #: Rendering error
    errors = Str()

    #: Process terminated intentionally
    terminated = Bool(False)

    #: Count restarts so we can detect issues with startup
    restarts = Int()

    #: Max number it will attempt to restart
    max_retries = Int(10)

    #: ID count for json-rpc requests
    _id = Int()

    #: Futures awaiting responses, keyed by request id
    _responses = Dict()

    #: Seconds between keep-alive pings
    _ping_rate = Int(40)

    #: Capture stderr separately
    err_to_out = set_default(False)

    def redraw(self):
        """ Trigger a re-render, either by bumping the document version
        or by re-sending the current version to the viewer.

        """
        if self.document:
            # Trigger a reload
            self.document.version += 1
        else:
            self.set_version(self._id)

    @observe('document', 'document.version')
    def _update_document(self, change):
        """ Push the current document name and version to the viewer. """
        doc = self.document
        if doc is None:
            self.set_filename('-')
        else:
            self.set_filename(doc.name)
            self.set_version(doc.version)

    def send_message(self, method, *args, **kwargs):
        """ Send a json-rpc request to the viewer process.

        If the transport or window id is not ready yet, the call is
        deferred and retried after one second.

        """
        # Defer until it's ready
        if not self.transport or not self.window_id:
            #log.debug('renderer | message not ready deferring')
            timed_call(1000, self.send_message, method, *args, **kwargs)
            return
        # BUGFIX: default to None so callers that omit _id do not raise a
        # KeyError (an absent id is explicitly supported below).
        _id = kwargs.pop('_id', None)
        _silent = kwargs.pop('_silent', False)

        request = {
            'jsonrpc': '2.0',
            'method': method,
            'params': args or kwargs
        }
        if _id is not None:
            request['id'] = _id
        if not _silent:
            log.debug(f'renderer | sent | {request}')
        encoded_msg = jsonpickle.dumps(request).encode() + b'\r\n'
        deferred_call(self.transport.write, encoded_msg)

    async def start(self):
        """ Launch the viewer subprocess attached to this protocol. """
        atexit.register(self.terminate)
        cmd = [sys.executable]
        if not sys.executable.endswith('declaracad'):
            cmd.extend(['-m', 'declaracad'])
        cmd.extend(['view', '-', '-f'])
        loop = asyncio.get_event_loop()
        self.process = await loop.subprocess_exec(lambda: self, *cmd)
        return self.process

    def restart(self):
        """ Restart the viewer process after a crash.

        Raises RuntimeError after `max_retries` consecutive failed
        attempts (the counter is reset when a window id is received).

        """
        self.window_id = 0
        self.restarts += 1

        if self.restarts > self.max_retries:
            plugin = self.plugin
            plugin.workbench.message_critical(
                "Viewer failed to start",
                "Could not get the viewer to start after several attempts.")

            raise RuntimeError(
                "renderer | Failed to successfully start renderer aborting!")

        log.debug(f"Attempting to restart viewer {self.process}")
        deferred_call(self.start)

    def connection_made(self, transport):
        """ Start the keep-alive ping loop once the process is up. """
        super().connection_made(transport)
        self.schedule_ping()
        self.terminated = False

    def data_received(self, data):
        """ Decode a line of output from the viewer and dispatch it.

        Recognized json-rpc response ids are handled specially; anything
        else is appended to the document output/errors.

        """
        line = data.decode()
        try:
            response = jsonpickle.loads(line)
            # log.debug(f"viewer | resp | {response}")
        except Exception:
            log.debug(f"viewer | out | {line.rstrip()}")
            response = {}

        doc = self.document

        if not isinstance(response, dict):
            # BUGFIX: the original logged response.rstrip(), which raises
            # AttributeError for non-str payloads (e.g. lists); log the
            # raw line instead.
            log.debug(f"viewer | out | {line.rstrip()}")
            return

        #: Special case for startup
        response_id = response.get('id')
        if response_id == 'window_id':
            self.window_id = response['result']
            self.restarts = 0  # Clear the restart count
            return
        elif response_id == 'keep_alive':
            return
        elif response_id == 'invoke_command':
            # Note: intentionally falls through to the generic handling
            # below after invoking the command.
            command_id = response.get('command_id')
            parameters = response.get('parameters', {})
            log.debug(f"viewer | out | {command_id}({parameters})")
            self.plugin.workbench.invoke_command(command_id, parameters)
        elif response_id == 'render_error':
            if doc:
                doc.errors.extend(response['error']['message'].split("\n"))
            return
        elif response_id == 'render_success':
            if doc:
                doc.errors = []
            return
        elif response_id == 'capture_output':
            # Script output capture it
            if doc:
                doc.output = response['result'].split("\n")
            return
        elif response_id == 'shape_selection':
            #: TODO: Do something with this?
            if doc:
                doc.output.append(str(response['result']))
            return
        elif response_id is not None:
            # Look up the future stored for this request id and resolve it
            # with the result or fail it with the reported error.
            d = self._responses.pop(response_id, None)
            if d is not None:
                try:
                    error = response.get('error')
                    if error is not None:
                        if doc:
                            doc.errors.extend(
                                error.get('message', '').split("\n"))
                        # BUGFIX: the original called
                        # d.add_done_callback(error) with a non-callable,
                        # which raised TypeError and left the future
                        # unresolved. Fail it properly instead.
                        if not d.done():
                            d.set_exception(
                                RuntimeError(error.get('message', '')))
                    else:
                        # BUGFIX: likewise, resolve the future with the
                        # result instead of add_done_callback(result).
                        if not d.done():
                            d.set_result(response.get('result'))
                    return
                except Exception as e:
                    log.warning("RPC response not properly handled!")
                    log.exception(e)

            else:
                log.warning("Got unexpected reply")
            # else we got a response from something else, possibly an error?
        if 'error' in response and doc:
            doc.errors.extend(response['error'].get('message', '').split("\n"))
            doc.output.append(line)
        elif 'message' in response and doc:
            doc.output.extend(response['message'].split("\n"))
        elif doc:
            # Append to output
            doc.output.extend(line.split("\n"))

    def err_received(self, data):
        """ Catch and log error output attempting to decode it

        """
        for line in data.split(b"\n"):
            if not line:
                continue
            # Suppress known-noisy Qt warnings
            if line.startswith(b"QWidget::") or line.startswith(b"QPainter::"):
                continue
            try:
                line = line.decode()
                log.debug(f"render | err | {line}")
                if self.document:
                    self.document.errors.append(line)
            except Exception:
                # Undecodable bytes: log the raw line and move on
                log.debug(f"render | err | {line}")

    def process_exited(self, reason=None):
        """ Restart the viewer unless it was terminated intentionally. """
        log.warning(f"renderer | process ended: {reason}")
        if not self.terminated:
            # Clear the filename on crash so it works when reset
            self.restart()
        log.warning("renderer | stdout closed")

    def terminate(self):
        """ Shut the process down intentionally (suppresses restart). """
        super(ViewerProcess, self).terminate()
        self.terminated = True

    def schedule_ping(self):
        """ Ping periodically so the process stays awake """
        if self.terminated:
            return
        # Ping the viewer to tell it to keep alive
        self.send_message("ping", _id="keep_alive", _silent=True)
        timed_call(self._ping_rate * 1000, self.schedule_ping)

    def __getattr__(self, name):
        """ Proxy all calls not defined here to the remote viewer.

        This makes doing `setattr(renderer, attr, value)` get passed to the
        remote viewer.

        """
        if name.startswith('set_'):

            def remote_viewer_call(*args, **kwargs):
                d = asyncio.Future()
                self._id += 1
                kwargs['_id'] = self._id
                self._responses[self._id] = d
                self.send_message(name, *args, **kwargs)
                return d

            return remote_viewer_call
        raise AttributeError("No attribute %s" % name)
コード例 #12
0
class TooltipCalibrationController(LiveCalibrationController):
    """ Live controller for tooltip calibration.

    Collects successive tooltip calibration results, stores the difference
    between consecutive results in a ring buffer, and reports success once
    `result_count` consecutive errors are all below `max_error`.
    """

    #: True once the connector has been set up and data is flowing.
    is_ready = Bool(False)

    #: Ring-buffer size: number of consecutive results required.
    result_count = Int(128)
    #: Ring buffer of errors between consecutive results (NaN = no data yet).
    errors = Typed(np.ndarray)
    #: Threshold below which the calibration is considered stable.
    max_error = Float(0.0)
    #: Error of the first comparison; stays -1 until one is computed.
    initial_error = Float(-1)

    #: Most recently received calibration result.
    last_result = Value(None)

    #: UI widgets resolved in setupController.
    results_txt = Value()
    progress_bar = Value()

    def setupController(self, active_widgets=None):
        """ Resolve UI widgets, configure sources/sinks and the error
        buffer, and hook up the facade observer.

        NOTE(review): `autocomplete_maxerror_str` and `facade` are not
        declared here -- presumably provided by the base class; confirm.
        """
        super(TooltipCalibrationController,
              self).setupController(active_widgets=active_widgets)
        if active_widgets is not None:
            w = active_widgets[0]
            self.results_txt = w.find('results_txt')
            self.progress_bar = w.find('progress_bar')

        if self.autocomplete_maxerror_str != "":
            self.max_error = float(self.autocomplete_maxerror_str)

        # needs to match the SRG !!
        self.sync_source = 'calib_tooltip'
        self.required_sinks = [
            'calib_tooltip',
        ]

        # setup a errors buffer (all NaN until enough results arrive)
        self.errors = np.array([np.nan] * self.result_count, dtype=np.double)

        if self.facade is not None:
            self.facade.observe("is_loaded", self.connector_setup)

    def connector_setup(self, change):
        """ Once the facade is loaded, attach the connector and subscribe
        to the sync source.
        """
        if change['value'] and self.verify_connector():
            self.connector.setup(self.facade.instance)
            self.connector.observe(self.sync_source, self.handle_data)
            self.is_ready = True

    def handle_data(self, c):
        """ Handle a new tooltip calibration result.

        Updates the displayed result, pushes the change-vs-previous-result
        error into the ring buffer, updates the progress bar, and stops the
        calibration once all buffered errors are below `max_error`.
        """
        if self.connector.calib_tooltip is not None:
            tc = self.connector.calib_tooltip.get()
            self.results_txt.text = "Result:\n%s" % str(tc)

            if self.last_result is not None:
                # Error is the distance between consecutive results.
                error = norm(tc - self.last_result)
                self.errors[0] = error
                # implements simple ringbuffer
                self.errors = np.roll(self.errors, 1)

                if self.initial_error == -1:
                    self.initial_error = error

            self.last_result = tc

            # update progress bar
            # (only reached after `error` has been computed at least once,
            # since initial_error is set in the same branch as `error`)
            if self.initial_error != -1:
                # NOTE(review): divides by (initial_error - max_error);
                # could divide by zero if they are equal -- confirm inputs.
                p = error / (self.initial_error - self.max_error)
                pv = int(np.sqrt(1 - max(0, min(p, 1))) * 100)
                if pv > self.progress_bar.value:
                    self.progress_bar.value = pv

            # check if the minimum of self.result_count results have been received
            if not np.isnan(np.sum(self.errors)):

                if np.all(self.errors < self.max_error):
                    log.info(
                        "Tooltip Calibration: Results are satisfactory (<%s) min: %s max: %s"
                        % (self.max_error, np.min(
                            self.errors), np.max(self.errors)))
                    # NOTE(review): result_ok / autocomplete_enable are not
                    # declared here -- presumably base-class members; verify.
                    self.result_ok = True
                    self.progress_bar.value = 100
                    if self.autocomplete_enable:
                        self.stopCalibration()
コード例 #13
0
ファイル: plugin.py プロジェクト: yangbiaocn/inkcut
class JobPlugin(Plugin):
    """ Plugin that manages the current job, saved jobs, the current
    material, and the recent documents list, and keeps the preview in
    sync with all of them.

    """
    #: Units used for display
    units = Enum(*unit_conversions.keys()).tag(config=True)

    #: Available materials
    materials = List(Material).tag(config=True)

    #: Current material
    material = Instance(Material, ()).tag(config=True)

    #: Previous jobs
    jobs = List(Job).tag(config=True)

    #: Current job
    job = Instance(Job).tag(config=True)

    #: Recently open paths (most recent last)
    recent_documents = List(Str()).tag(config=True)

    #: Maximum number of recent documents to remember
    recent_document_limit = Int(10).tag(config=True)
    #: Maximum number of saved jobs to keep
    saved_jobs_limit = Int(100).tag(config=True)

    #: Timeout for optimizing paths
    optimizer_timeout = Float(10, strict=False).tag(config=True)

    def _default_job(self):
        """ Create a fresh job bound to the current material. """
        return Job(material=self.material)

    def _default_units(self):
        """ Default display units (inches). """
        return 'in'

    # -------------------------------------------------------------------------
    # Plugin API
    # -------------------------------------------------------------------------
    def start(self):
        """ Register the plugins this plugin depends on

        """
        #: Now load state
        super(JobPlugin, self).start()

        #: If we loaded from state, refresh
        if self.job.document:
            self.refresh_preview()

        self.init_recent_documents_menu()

    # -------------------------------------------------------------------------
    # Job API
    # -------------------------------------------------------------------------
    def request_approval(self, job):
        """ Request approval to start a job. This will set the job.info.status
        to either 'approved' or 'cancelled'.

        """
        ui = self.workbench.get_plugin('enaml.workbench.ui')
        with enaml.imports():
            from .dialogs import JobApprovalDialog
        JobApprovalDialog(ui.window, plugin=self, job=job).exec_()

    def refresh_preview(self):
        """ Refresh the preview. Other plugins can request this """
        self._refresh_preview({})

    def open_document(self, path, nodes=None):
        """ Set the job.document if it is empty, otherwise close and create
        a new Job instance.

        Parameters
        ----------
        path : str
            Path to the document, or '-' to read from stdin.
        nodes : list, optional
            Specific node ids to load from the document.

        Raises
        ------
        JobError
            If the path does not exist, is not a file, or cannot be parsed.

        """
        if path == '-':
            log.debug("Opening document from stdin...")
        elif not os.path.exists(path):
            raise JobError("Cannot open %s, it does not exist!" % path)
        elif not os.path.isfile(path):
            raise JobError("Cannot open %s, it is not a file!" % path)

        # Close any old docs
        self.close_document()

        log.info("Opening {doc}".format(doc=path))
        try:
            self.job.document_kwargs = dict(ids=nodes)
            self.job.document = path
        except ValueError as e:
            #: Wrap in a JobError, chaining the cause for debugging
            raise JobError(e) from e

        # Update recent documents
        if path != '-':
            docs = self.recent_documents[:]
            # Remove and re-add to make it the most recent entry
            if path in docs:
                docs.remove(path)
            docs.append(path)

            # BUGFIX: honor the configurable recent_document_limit instead
            # of the hard-coded 10 previously used here.
            if len(docs) > self.recent_document_limit:
                docs.pop(0)

            self.recent_documents = docs

    def save_document(self):
        """ Append the current job to the saved jobs list, cloning it if
        it was already saved so later edits don't mutate the saved copy.

        """
        # Copy so the ui's update
        job = self.job
        jobs = self.jobs[:]
        if job in jobs:
            # Save a copy or any changes will update the copy as well
            job = job.clone()
        jobs.append(job)

        # Limit size
        if len(jobs) > self.saved_jobs_limit:
            jobs.pop(0)

        self.jobs = jobs

    def close_document(self):
        """ If the job currently has a "document" add this to the jobs list
        and create a new Job instance. Otherwise no job is open so do nothing.

        """
        if not self.job.document:
            return

        log.info("Closing {doc}".format(doc=self.job.document))
        # Create a new default job
        self.job = self._default_job()

    @observe('job.material')
    def _observe_material(self, change):
        """ Keep the job material and plugin material in sync.

        """
        m = self.material
        job = self.job
        if job.material != m:
            job.material = m

    @observe('job', 'job.model', 'job.material', 'material.size',
             'material.padding')
    def _refresh_preview(self, change):
        """ Redraw the preview on the screen

        """
        log.info(change)
        view_items = []

        #: Transform used by the view
        preview_plugin = self.workbench.get_plugin('inkcut.preview')
        job = self.job
        plot = preview_plugin.preview
        t = preview_plugin.transform

        #: Draw the device
        plugin = self.workbench.get_plugin('inkcut.device')
        device = plugin.device

        #: Apply the final output transforms from the device
        transform = device.transform if device else lambda p: p

        if device and device.area:
            area = device.area
            view_items.append(
                dict(path=transform(device.area.path * t),
                     pen=plot.pen_device,
                     skip_autorange=True)  #(False, [area.size[0], 0]))
            )

        #: The model is only set when a document is open and has no errors
        if job.model:
            view_items.extend([
                dict(path=transform(job.move_path), pen=plot.pen_up),
                dict(path=transform(job.cut_path), pen=plot.pen_down)
            ])
            #: TODO: This
            #if self.show_offset_path:
            #    view_items.append(PainterPathPlotItem(
            # self.job.offset_path,pen=self.pen_offset))
        if job.material:
            # Also observe any change to job.media and job.device
            view_items.extend([
                dict(path=transform(job.material.path * t),
                     pen=plot.pen_media,
                     skip_autorange=([0, job.size[0]], [0, job.size[1]])),
                dict(path=transform(job.material.padding_path * t),
                     pen=plot.pen_media_padding,
                     skip_autorange=True)
            ])

        #: Update the plot
        preview_plugin.set_preview(*view_items)

        #: Save config
        self.save()

    # -------------------------------------------------------------------------
    # Utilities
    # -------------------------------------------------------------------------

    def init_recent_documents_menu(self):
        """ Insert the `RecentDocumentsMenu` into the Menu declaration that
        automatically updates the recent document menu links.

        """
        recent_menu = self.get_recent_menu()
        if recent_menu is None:
            return
        for c in recent_menu.children:
            if isinstance(c, RecentDocumentsMenu):
                return  # Already added
        documents_menu = RecentDocumentsMenu(plugin=self, parent=recent_menu)
        documents_menu.initialize()

    def get_recent_menu(self):
        """ Get the recent menu item WorkbenchMenu, or None if the
        '/file/recent/' menu is not present.

        """
        ui = self.workbench.get_plugin('enaml.workbench.ui')
        window_model = ui._model
        if not window_model:
            return
        for menu in window_model.menus:
            if menu.item.path == '/file':
                for c in menu.children:
                    if isinstance(c, WorkbenchMenu):
                        if c.item.path == '/file/recent/':
                            return c
コード例 #14
0
class Extension(Declarative):
    """ A declarative class which represents a plugin extension.

    An Extension must be declared as a child of a PluginManifest.

    """
    #: The globally unique identifier for the extension.
    id = d_(Unicode())

    #: The fully qualified id of the target extension point.
    point = d_(Unicode())

    #: An optional rank to use for order the extension among others.
    rank = d_(Int())

    #: A callable which will create the implementation object for the
    #: extension point. The call signature and return type are defined
    #: by the extension point plugin which invokes the factory.
    factory = d_(Callable())

    #: An optional description of the extension.
    description = d_(Unicode())

    @property
    def plugin_id(self):
        """ The id of the parent plugin manifest.

        """
        return self.parent.id

    @property
    def qualified_id(self):
        """ The fully qualified extension identifier.

        If the declared id already contains a dot it is returned as-is;
        otherwise it is prefixed with the parent plugin id.

        """
        ext_id = self.id
        if u'.' not in ext_id:
            return u'%s.%s' % (self.plugin_id, ext_id)
        return ext_id

    def get_child(self, kind, reverse=False):
        """ Find a child by the given type.

        Parameters
        ----------
        kind : type
            The declarative type of the child of interest.

        reverse : bool, optional
            Whether to search in reversed order. The default is False.

        Returns
        -------
        result : child or None
            The first child found of the requested type.

        """
        ordering = reversed(self.children) if reverse else self.children
        return next((c for c in ordering if isinstance(c, kind)), None)

    def get_children(self, kind):
        """ Get all the children of the given type.

        Parameters
        ----------
        kind : type
            The declarative type of the children of interest.

        Returns
        -------
        result : list
            The list of children of the request type.

        """
        return [child for child in self.children if isinstance(child, kind)]
コード例 #15
0
ファイル: qt_tool_bar.py プロジェクト: ylwb/enaml
class QtToolBar(QtConstraintsWidget, ProxyToolBar):
    """ A Qt implementation of an Enaml ToolBar.

    """
    #: A reference to the widget created by the proxy.
    widget = Typed(QCustomToolBar)

    #: Cyclic notification guard. This a bitfield of multiple guards.
    _guard = Int(0)

    #--------------------------------------------------------------------------
    # Initialization API
    #--------------------------------------------------------------------------
    def create_widget(self):
        """ Create the QCustomToolBar widget.

        """
        self.widget = QCustomToolBar(self.parent_widget())

    def init_widget(self):
        """ Initialize the tool bar widget.

        Pushes the declaration state into the widget and connects the
        float/dock signal handlers.

        """
        super(QtToolBar, self).init_widget()
        d = self.declaration
        self.set_button_style(d.button_style)
        self.set_movable(d.movable)
        self.set_floatable(d.floatable)
        self.set_floating(d.floating)
        self.set_dock_area(d.dock_area)
        self.set_allowed_dock_areas(d.allowed_dock_areas)
        self.set_orientation(d.orientation)
        widget = self.widget
        widget.floated.connect(self.on_floated)
        widget.docked.connect(self.on_docked)

    def init_layout(self):
        """ Initialize the layout for the toolbar.

        Adds the actions of each child in declaration order.

        """
        super(QtToolBar, self).init_layout()
        widget = self.widget
        for child in self.children():
            if isinstance(child, QtAction):
                widget.addAction(child.widget)
            elif isinstance(child, QtActionGroup):
                widget.addActions(child.actions())
            elif isinstance(child, QtWidget):
                widget.addAction(child.get_action(True))

    #--------------------------------------------------------------------------
    # Child Events
    #--------------------------------------------------------------------------
    def find_next_action(self, child):
        """ Locate the QAction object which logically follows the child.

        Parameters
        ----------
        child : QtToolkitObject
            The child object of interest.

        Returns
        -------
        result : QAction or None
            The QAction which logically follows the position of the
            child in the list of children. None will be returned if
            a relevant QAction is not found.

        """
        found = False
        for dchild in self.children():
            if found:
                if isinstance(dchild, QtAction):
                    return dchild.widget
                elif isinstance(dchild, QtActionGroup):
                    actions = dchild.actions()
                    if len(actions) > 0:
                        return actions[0]
                elif isinstance(dchild, QtWidget):
                    action = dchild.get_action(False)
                    if action is not None:
                        return action
            else:
                # Keep scanning until the given child is reached; the
                # first actionable sibling after it is the answer.
                found = dchild is child

    def child_added(self, child):
        """ Handle the child added event for a QtToolBar.

        This handler will scan the children to find the proper point
        at which to insert the action.

        """
        super(QtToolBar, self).child_added(child)
        if isinstance(child, QtAction):
            before = self.find_next_action(child)
            self.widget.insertAction(before, child.widget)
        elif isinstance(child, QtActionGroup):
            before = self.find_next_action(child)
            self.widget.insertActions(before, child.actions())
        elif isinstance(child, QtWidget):
            before = self.find_next_action(child)
            self.widget.insertAction(before, child.get_action(True))

    def child_removed(self, child):
        """  Handle the child removed event for a QtToolBar.

        """
        super(QtToolBar, self).child_removed(child)
        if isinstance(child, QtAction):
            self.widget.removeAction(child.widget)
        elif isinstance(child, QtActionGroup):
            self.widget.removeActions(child.actions())
        elif isinstance(child, QtWidget):
            self.widget.removeAction(child.get_action(False))

    #--------------------------------------------------------------------------
    # Signal Handlers
    #--------------------------------------------------------------------------
    def on_floated(self):
        """ The signal handler for the 'floated' signal.

        The guard prevents the declaration update from re-triggering
        set_floating on this proxy.

        """
        if not self._guard & FLOATED_GUARD:
            self._guard |= FLOATED_GUARD
            try:
                self.declaration.floating = True
            finally:
                self._guard &= ~FLOATED_GUARD

    def on_docked(self, area):
        """ The signal handler for the 'docked' signal.

        """
        if not self._guard & FLOATED_GUARD:
            self._guard |= FLOATED_GUARD
            try:
                self.declaration.floating = False
                self.declaration.dock_area = DOCK_AREAS_INV[area]
            finally:
                self._guard &= ~FLOATED_GUARD

    #--------------------------------------------------------------------------
    # ProxyToolBar API
    #--------------------------------------------------------------------------
    def set_button_style(self, style):
        """ Set the button style for the toolbar.

        """
        self.widget.setToolButtonStyle(BUTTON_STYLES[style])

    def set_movable(self, movable):
        """ Set the movable state on the underlying widget.

        """
        self.widget.setMovable(movable)

    def set_floatable(self, floatable):
        """ Set the floatable state on the underlying widget.

        """
        self.widget.setFloatable(floatable)

    def set_floating(self, floating):
        """ Set the floating state on the underlying widget.

        """
        if not self._guard & FLOATED_GUARD:
            self._guard |= FLOATED_GUARD
            try:
                self.widget.setFloating(floating)
            finally:
                self._guard &= ~FLOATED_GUARD

    def set_dock_area(self, dock_area):
        """ Set the dock area on the underlying widget.

        """
        self.widget.setToolBarArea(DOCK_AREAS[dock_area])

    def set_allowed_dock_areas(self, dock_areas):
        """ Set the allowed dock areas on the underlying widget.

        """
        qt_areas = Qt.NoToolBarArea
        for area in dock_areas:
            qt_areas |= DOCK_AREAS[area]
        self.widget.setAllowedAreas(qt_areas)

    def set_orientation(self, orientation):
        """ Set the orientation of the underlying widget.

        """
        # If the tool bar is a child of a QMainWindow, then that window
        # will take control of setting its orientation and changes to
        # the orientation by the user must be ignored.
        widget = self.widget
        parent = widget.parent()
        if not isinstance(parent, QMainWindow):
            widget.setOrientation(ORIENTATIONS[orientation])
コード例 #16
0
ファイル: form.py プロジェクト: zxyfanta/enaml
class Form(Container):
    """ A Container subclass that arranges its children in two columns.

    The left column is typically Labels, but this is not a requirement.
    The right are the actual widgets for data entry. The children should
    be in alternating label/widget order. If there are an odd number
    of children, the last child will span both columns.

    The Form provides an extra constraint variable, 'midline', which
    is used as the alignment anchor for the columns.

    """
    #: The ConstraintVariable giving the midline along which the labels
    #: and widgets are aligned.
    midline = ConstraintMember()

    #: The spacing to place between the form rows, in pixels.
    row_spacing = d_(Int(10))

    #: The spacing to place between the form columns, in pixels.
    column_spacing = d_(Int(10))

    #--------------------------------------------------------------------------
    # Observers
    #--------------------------------------------------------------------------
    @observe('row_spacing', 'column_spacing')
    def _layout_invalidated(self, change):
        """ A private observer which invalidates the layout.

        Triggered whenever the row or column spacing changes so the
        constraints are regenerated with the new spacer sizes.

        """
        # The superclass handler is sufficient.
        super(Form, self)._layout_invalidated(change)

    #--------------------------------------------------------------------------
    # Layout Constraints
    #--------------------------------------------------------------------------
    def layout_constraints(self):
        """ Get the layout constraints for a Form.

        A Form supplies default constraints which will arrange the
        children in a two column layout. User defined 'constraints'
        will be added on top of the generated form constraints.

        This method cannot be overridden from Enaml syntax.

        """
        # Children at even indices form the label column; children at
        # odd indices form the widget column.
        children = self.visible_widgets()
        labels = children[::2]
        widgets = children[1::2]
        n_labels = len(labels)
        n_widgets = len(widgets)
        # With an odd number of children, the extra child is pulled out
        # of its column and later laid out to span the full width.
        if n_labels != n_widgets:
            if n_labels > n_widgets:
                odd_child = labels.pop()
            else:
                odd_child = widgets.pop()
        else:
            odd_child = None

        # Boundary flex spacer
        b_flx = spacer(0).flex()

        # Inter-column flex spacer
        c_flx = spacer(max(0, self.column_spacing)).flex()

        # Inter-row flex spacer
        r_flx = spacer(max(0, self.row_spacing)).flex()

        # Generate the row constraints and make the column stacks
        midline = self.midline
        top = self.contents_top
        left = self.contents_left
        right = self.contents_right
        constraints = self.constraints[:]
        column1 = [top, b_flx]
        column2 = [top, b_flx]
        # Bind the append methods locally; they are called repeatedly in
        # the loop below.
        push = constraints.append
        push_col1 = column1.append
        push_col2 = column2.append
        for label, widget in zip(labels, widgets):
            # Anchor each widget's left edge to the shared midline so the
            # second column lines up across all rows.
            push((widget.left == midline) | 'strong')
            push(align('v_center', label, widget) | 'strong')
            push(horizontal(left, b_flx, label, c_flx, widget, b_flx, right))
            push_col1(label)
            push_col1(r_flx)
            push_col2(widget)
            push_col2(r_flx)

        # Handle the odd child and create the column constraints
        if odd_child is not None:
            # The odd child joins both vertical stacks and spans the full
            # content width on its own row.
            push_col1(odd_child)
            push_col2(odd_child)
            push(horizontal(left, b_flx, odd_child, b_flx, right))
        else:
            # Drop the trailing row spacer left over from the loop above.
            column1.pop()
            column2.pop()
        bottom = self.contents_bottom
        push_col1(b_flx)
        push_col1(bottom)
        push_col2(b_flx)
        push_col2(bottom)
        push(vertical(*column1))
        push(vertical(*column2))

        return constraints
コード例 #17
0
class DeviceConfig(Model):
    """ The default device configuration. Custom devices may want to subclass
    this.

    """
    #: Time between each path command, in ms.
    #: Time to wait between each step so we don't get
    #: way ahead of the cutter and fill up its buffer.
    step_time = Float(strict=False).tag(config=True)
    #: Custom output rate override; -1 means "not set" — TODO confirm units
    custom_rate = Float(-1, strict=False).tag(config=True)

    #: Distance between each command in user units
    #: this is effectively the resolution the software supplies
    step_size = Float(parse_unit('1mm'), strict=False).tag(config=True)

    #: Interpolate paths breaking them into small sections that
    #: can be sent. This allows pausing mid plot as many devices do not have
    #: a cancel command.
    interpolate = Bool(False).tag(config=True)

    #: How often the position will be updated in ms. Low power devices should
    #: set this to a high number like 2000 or 3000
    sample_rate = Int(100).tag(config=True)

    #: Final output rotation
    rotation = Enum(0, 90, -90).tag(config=True)

    #: Swap x and y axis
    swap_xy = Bool().tag(config=True)
    #: Mirror the output along the y / x axis respectively
    mirror_y = Bool().tag(config=True)
    mirror_x = Bool().tag(config=True)

    #: Final output scaling as [x, y] factors
    scale = ContainerList(Float(strict=False), default=[1, 1]).tag(config=True)

    #: Defines prescaling before conversion to a polygon
    quality_factor = Float(1, strict=False).tag(config=True)

    #: Device speed; units selected by speed_units, enabled by speed_enabled
    speed = Float(4, strict=False).tag(config=True)
    speed_units = Enum('in/s', 'cm/s').tag(config=True)
    speed_enabled = Bool().tag(config=True)

    #: Force in g
    force = Float(40, strict=False).tag(config=True)
    force_units = Enum('g').tag(config=True)
    force_enabled = Bool().tag(config=True)

    #: Use absolute coordinates
    absolute = Bool().tag(config=True)

    #: Device output is spooled by an external service
    #: this will cause the job to run with no delays between commands
    spooled = Bool().tag(config=True)

    #: Use a virtual connection
    test_mode = Bool().tag(config=True)

    #: Raw command strings sent before/after a job and on
    #: connect/disconnect.
    commands_before = Unicode().tag(config=True)
    commands_after = Unicode().tag(config=True)
    commands_connect = Unicode().tag(config=True)
    commands_disconnect = Unicode().tag(config=True)

    def _default_step_time(self):
        """ Determine the step time based on the device speed setting

        Converts the configured speed to mm/s and returns the time in ms
        needed to travel one step_size at that speed. Returns 0 when the
        speed is 0 to avoid a division by zero.

        """
        #: Convert speed to px/s then to mm/s
        units = self.speed_units.split("/")[0]
        speed = parse_unit('%s%s' % (self.speed, units))
        speed = to_unit(speed, 'mm')
        if speed == 0:
            return 0

        #: Now determine the time and convert to ms
        return max(0, round(1000 * self.step_size / speed))

    @observe('speed', 'speed_units', 'step_size')
    def _update_step_time(self, change):
        # Recompute step_time whenever any of its inputs change; the
        # 'update' check skips the initial 'create' notifications.
        if change['type'] == 'update':
            self.step_time = self._default_step_time()
コード例 #18
0
ファイル: test_validators.py プロジェクト: mrakitin/atom
"""
import sys
import pytest
from atom.compat import long

from atom.api import (CAtom, Atom, Value, Bool, Int, Long, Range, Float,
                      FloatRange, Bytes, Str, Unicode, Enum, Callable, Coerced,
                      Tuple, List, ContainerList, Dict, Instance,
                      ForwardInstance, Typed, ForwardTyped, Subclass,
                      ForwardSubclass, Event)


@pytest.mark.parametrize("member, set_values, values, raising_values", [
    (Value(), ['a', 1, None], ['a', 1, None], []),
    (Bool(), [True, False], [True, False], 'r'),
    (Int(), [1], [1], [1.0, long(1)] if sys.version_info < (3, ) else [1.0]),
    (Int(strict=False), [1, 1.0, int(1)], 3 * [1], ['a']),
    (Long(strict=True), [long(1)], [long(1)], [1.0, 1] if sys.version_info <
     (3, ) else [0.1]),
    (Long(strict=False), [1, 1.0, int(1)], 3 * [1], ['a']),
    (Range(0, 2), [0, 2], [0, 2], [-1, 3]),
    (Range(2, 0), [0, 2], [0, 2], [-1, 3]),
    (Range(0), [0, 3], [0, 3], [-1]),
    (Range(high=2), [-1, 2], [-1, 2], [3]),
    (Float(), [1, int(1), 1.1], [1.0, 1.0, 1.1], ['']),
    (Float(strict=True), [1.1], [1.1], [1]),
    (FloatRange(0.0, 0.5), [0.0, 0.5], [0.0, 0.5], [-0.1, 0.6]),
    (FloatRange(0.5, 0.0), [0.0, 0.5], [0.0, 0.5], [-0.1, 0.6]),
    (FloatRange(0.0), [0.0, 0.6], [0.0, 0.6], [-0.1]),
    (FloatRange(high=0.5), [-0.3, 0.5], [-0.3, 0.5], [0.6]),
    (Bytes(), [b'a', u'a'], [b'a'] * 2, [1]),
コード例 #19
0
ファイル: plugin.py プロジェクト: Qcircuits/ecpy
class ErrorsPlugin(Plugin):
    """Plugin in charge of collecting the errors.

    It will always log the errors, and will notify the user according to their
    type.

    """
    #: Errors for which a custom handler is registered.
    errors = List()

    def start(self):
        """Collect extensions contributing error handlers.

        """
        checker = make_extension_validator(ErrorHandler, ('handle',))
        self._errors_handlers = ExtensionsCollector(workbench=self.workbench,
                                                    point=ERR_HANDLER_POINT,
                                                    ext_class=ErrorHandler,
                                                    validate_ext=checker)
        self._errors_handlers.start()
        self._update_errors(None)
        self._errors_handlers.observe('contributions', self._update_errors)

    def stop(self):
        """Stop the extension collector and clear the list of handlers.

        """
        self._errors_handlers.unobserve('contributions', self._update_errors)
        self._errors_handlers.stop()
        self.errors = []

    def signal(self, kind, **kwargs):
        """Signal an error occurred in the system.

        Parameters
        ----------
        kind : unicode or None
            Kind of error which occurred. If a specific handler is found, it is
            used, otherwise the generic handling method is used.

        **kwargs :
            Arguments to pass to the error handler.

        """
        # While in gathering mode, store the report for later handling.
        if self._gathering_counter:
            self._delayed[kind].append(kwargs)
            return

        widget = self._handle(kind, kwargs)

        if widget:
            # Show dialog in application modal mode
            dial = ErrorsDialog(errors={kind: widget})
            deferred_call(dial.exec_)

    def report(self, kind=None):
        """Show a widget summarizing all the errors.

        Parameters
        ----------
        kind : unicode, optional
            If specified only the error related to the specified kind will
            be reported.

        """
        handlers = self._errors_handlers.contributions
        errors = {}
        if kind:
            if kind not in handlers:
                msg = '''{} is not a registered error kind (it has no
                    associated handler)'''.format(kind)
                self.signal('error',
                            message=cleandoc(msg).replace('\n', ' '))
                return

            handlers = {kind: handlers[kind]}

        for kind in handlers:
            report = handlers[kind].report(self.workbench)
            if report:
                errors[kind] = report

        dial = ErrorsDialog(errors=errors)
        dial.exec_()

    def enter_error_gathering(self):
        """In gathering mode, error handling is deferred till exiting the mode.

        """
        self._gathering_counter += 1

    def exit_error_gathering(self):
        """Upon leaving gathering mode, errors are handled.

        If error handling should lead to a window display, all widgets are
        collected and displayed in a single window.
        As the gathering mode can be requested many times, the errors are only
        handled when this method has been called as many times as its
        counterpart.

        """
        self._gathering_counter -= 1
        if self._gathering_counter < 1:
            # Make sure to also gather additional errors signal during errors
            # handling
            self._gathering_counter += 1

            # Handle all delayed errors
            errors = {}
            while self._delayed:
                delayed = self._delayed.copy()
                self._delayed.clear()
                for kind in delayed:
                    res = self._handle(kind, delayed[kind])
                    if res:
                        errors[kind] = res

            self._gathering_counter = 0

            if errors:
                dial = ErrorsDialog(errors=errors)
                deferred_call(dial.exec_)

    def install_excepthook(self):
        """Setup a global sys.excepthook for a nicer user experience.

        The error message suggests to the user to restart the app. In the
        future adding an automatic bug report system here would make sense.

        """
        def exception_handler(cls, value, traceback):
            """Log the error and signal to the user that it should restart the
            app.

            """
            # Fixed typos in the logged and user-facing messages
            # ("uncaugt execption occured" -> "uncaught exception occurred").
            msg = 'An uncaught exception occurred :\n%s : %s\nTraceback:\n%s'
            logger.error(msg % (cls.__name__, value,
                         ''.join(format_tb(traceback))))

            ui = self.workbench.get_plugin('enaml.workbench.ui')
            msg = ('An uncaught exception occurred. This should not happen '
                   'and can have a number of side effects. It is hence '
                   'advised to save your work and restart the application.')
            warning(ui.window, 'Consider restart', fill(msg))

        import sys
        sys.excepthook = exception_handler

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    #: Contributed error handlers.
    _errors_handlers = Typed(ExtensionsCollector)

    #: Counter keeping track of how many times the gathering mode was entered
    #: the mode is exited only when the value reaches 0.
    _gathering_counter = Int()

    #: List of pairs (kind, kwargs) representing the error reports received
    #: while the gathering mode was active.
    _delayed = Typed(defaultdict, (list,))

    def _update_errors(self, change):
        """Update the list of supported errors when the registered handlers
        change

        """
        self.errors = list(self._errors_handlers.contributions)

    def _handle(self, kind, infos):
        """Dispatch error report to appropriate handler.

        Falls back to the generic handler when no specific handler is
        registered for the given kind. Handler failures are themselves
        reported through the error signaling command.

        """
        if kind in self._errors_handlers.contributions:
            handler = self._errors_handlers.contributions[kind]
            try:
                return handler.handle(self.workbench, infos)
            except Exception:
                try:
                    msg = ('Failed to handle %s error, infos were:\n' % kind +
                           pformat(infos) + '\nError was :\n' + format_exc())
                except Exception:
                    msg = ('Failed to handle %s error, and to ' % kind +
                           'format infos:\n' + format_exc())
                core = self.workbench.get_plugin('enaml.workbench.core')
                core.invoke_command('ecpy.app.errors.signal',
                                    dict(kind='error', message=msg))

        else:
            return self._handle_unknown(kind, infos)

    def _handle_unknown(self, kind, infos):
        """Generic handler for unregistered kinds of errors.

        """
        try:
            # Delayed handling may pass a list of infos dicts rather than
            # a single dict.
            if not isinstance(infos, dict):
                msg = '\n\n'.join((pformat(i) for i in infos))

            else:
                msg = pformat(infos)

        except Exception:
            msg = 'Failed to format the errors infos.\n' + format_exc()

        logger.debug('No handler found for "%s" kind of error:\n %s',
                     kind, msg)

        return UnknownErrorWidget(kind=kind, msg=msg)

    #: Backward-compatible alias preserving the historical misspelled name.
    _handle_unknwon = _handle_unknown
コード例 #20
0
ファイル: test_validators.py プロジェクト: mrakitin/atom
    class EventValidationTest(Atom):
        """Atom subclass exercising payload validation of Event members."""

        # Event validated by a Member instance: payloads go through the
        # Int() member's validation machinery.
        ev_member = Event(Int())

        # Event validated against a plain type: payloads must be ints.
        ev_type = Event(int)
コード例 #21
0
ファイル: spin_box.py プロジェクト: ylwb/enaml
class SpinBox(Control):
    """ A control for editing integer values with step buttons.

    The value is kept clipped to the [minimum, maximum] range, and the
    bounds are adjusted automatically to stay mutually consistent.

    """
    #: The lowest value the spin box will accept. Defaults to 0.
    minimum = d_(Int(0))

    #: The highest value the spin box will accept. Defaults to 100.
    maximum = d_(Int(100))

    #: The current value of the spin box, always clipped so that it
    #: falls between the minimum and maximum.
    value = d_(Int(0))

    #: Optional text rendered before the displayed value.
    prefix = d_(Str())

    #: Optional text rendered after the displayed value.
    suffix = d_(Str())

    #: Optional replacement text shown when the value sits at the
    #: minimum, giving it a special meaning such as "Auto".
    special_value_text = d_(Str())

    #: Amount the value changes per step. Defaults to 1.
    single_step = d_(Range(low=1))

    #: When True the user cannot edit the value, though the displayed
    #: text can still be copied to the clipboard.
    read_only = d_(Bool(False))

    #: When True the value wraps around at the extremes.
    #: Defaults to False.
    wrapping = d_(Bool(False))

    #: A spin box expands freely in width by default.
    hug_width = set_default('ignore')

    #: A reference to the ProxySpinBox object.
    proxy = Typed(ProxySpinBox)

    #--------------------------------------------------------------------------
    # Observers
    #--------------------------------------------------------------------------
    @observe('minimum', 'maximum', 'value', 'prefix', 'suffix',
             'special_value_text', 'single_step', 'read_only', 'wrapping')
    def _update_proxy(self, change):
        """ Forward state changes to the proxy widget.

        """
        # Delegate entirely to the superclass implementation.
        super(SpinBox, self)._update_proxy(change)

    #--------------------------------------------------------------------------
    # PostSetAttr Handlers
    #--------------------------------------------------------------------------
    def _post_setattr_minimum(self, old, new):
        """ Keep the maximum and value consistent with a raised minimum.

        Whenever the new minimum exceeds the current maximum or value,
        those attributes are bumped up to the new minimum.

        """
        if self.maximum < new:
            self.maximum = new
        if self.value < new:
            self.value = new

    def _post_setattr_maximum(self, old, new):
        """ Keep the minimum and value consistent with a lowered maximum.

        Whenever the new maximum falls below the current minimum or
        value, those attributes are pulled down to the new maximum.

        """
        if self.minimum > new:
            self.minimum = new
        if self.value > new:
            self.value = new

    #--------------------------------------------------------------------------
    # Post Validation Handlers
    #--------------------------------------------------------------------------
    def _post_validate_value(self, old, new):
        """ Clip the incoming value to the [minimum, maximum] range.

        """
        return min(max(new, self.minimum), self.maximum)
コード例 #22
0
ファイル: qt_slider.py プロジェクト: MShaffar19/enaml
class QtSlider(QtControl, ProxySlider):
    """ A Qt implementation of an Enaml ProxySlider.

    """
    #: A reference to the widget created by the proxy.
    widget = Typed(QSlider)

    #: Cyclic notification guard flags.
    _guard = Int(0)

    #--------------------------------------------------------------------------
    # Initialization API
    #--------------------------------------------------------------------------
    def create_widget(self):
        """ Instantiate the QSlider under the proxy's parent widget.

        """
        self.widget = QSlider(self.parent_widget())

    def init_widget(self):
        """ Push the declaration state onto the freshly created widget.

        """
        super(QtSlider, self).init_widget()
        d = self.declaration
        # Apply the declared attributes in the same order that a plain
        # sequence of setter calls would use.
        for apply_setting, value in (
                (self.set_minimum, d.minimum),
                (self.set_maximum, d.maximum),
                (self.set_value, d.value),
                (self.set_orientation, d.orientation),
                (self.set_page_step, d.page_step),
                (self.set_single_step, d.single_step),
                (self.set_tick_interval, d.tick_interval),
                (self.set_tick_position, d.tick_position),
                (self.set_tracking, d.tracking)):
            apply_setting(value)
        self.widget.valueChanged.connect(self.on_value_changed)

    #--------------------------------------------------------------------------
    # Signal Handlers
    #--------------------------------------------------------------------------
    def on_value_changed(self):
        """ Send the 'value_changed' action to the Enaml widget when the
        slider value has changed.

        The guard flag prevents an infinite notification loop between the
        widget and the declaration.

        """
        if self._guard & VALUE_FLAG:
            return
        self._guard |= VALUE_FLAG
        try:
            self.declaration.value = self.widget.value()
        finally:
            self._guard &= ~VALUE_FLAG

    #--------------------------------------------------------------------------
    # ProxySlider API
    #--------------------------------------------------------------------------
    def set_maximum(self, maximum):
        """ Push the maximum value down to the Qt widget.

        """
        self.widget.setMaximum(maximum)

    def set_minimum(self, minimum):
        """ Push the minimum value down to the Qt widget.

        """
        self.widget.setMinimum(minimum)

    def set_value(self, value):
        """ Push the slider value down to the Qt widget.

        Guarded against re-entrancy with the same flag used by the
        valueChanged signal handler.

        """
        if self._guard & VALUE_FLAG:
            return
        self._guard |= VALUE_FLAG
        try:
            self.widget.setValue(value)
        finally:
            self._guard &= ~VALUE_FLAG

    def set_page_step(self, page_step):
        """ Push the page step down to the Qt widget.

        """
        self.widget.setPageStep(page_step)

    def set_single_step(self, single_step):
        """ Push the single step down to the Qt widget.

        """
        self.widget.setSingleStep(single_step)

    def set_tick_interval(self, interval):
        """ Push the tick interval down to the Qt widget.

        """
        self.widget.setTickInterval(interval)

    def set_tick_position(self, tick_position):
        """ Map and push the tick position down to the Qt widget.

        """
        self.widget.setTickPosition(TICK_POSITION[tick_position])

    def set_orientation(self, orientation):
        """ Map and push the orientation down to the Qt widget.

        """
        self.widget.setOrientation(ORIENTATION[orientation])

    def set_tracking(self, tracking):
        """ Push the tracking flag down to the Qt widget.

        """
        self.widget.setTracking(tracking)
コード例 #23
0
class AbsoluteOrientationCalibrationController(LiveCalibrationController):
    """Live controller for the absolute-orientation calibration step.

    Accumulates the translation and rotation deltas between successive
    results in fixed-size ring buffers and declares the calibration
    complete once all buffered errors fall below the configured limits.
    """

    #: True once the dataflow connector has been set up.
    is_ready = Bool(False)

    #: Size of the error ring buffers (minimum number of results needed).
    result_count = Int(128)
    #: Ring buffers of translation / rotation deltas (NaN == not yet filled).
    errors_translation = Typed(np.ndarray)
    errors_rotation = Typed(np.ndarray)

    #: Acceptance thresholds, parsed from autocomplete_maxerror_str.
    max_error_translation = Float(0.0)
    max_error_rotation = Float(0.0)

    #: First observed deltas, used to scale the progress bar (-1 == unset).
    initial_error_translation = Float(-1)
    initial_error_rotation = Float(-1)

    #: Most recently received pose result.
    last_result = Value(None)

    #: UI widgets looked up in setupController.
    results_txt = Value()
    progress_bar = Value()

    def setupController(self, active_widgets=None):
        """Resolve UI widgets, parse thresholds and allocate error buffers.

        """
        super(AbsoluteOrientationCalibrationController,
              self).setupController(active_widgets=active_widgets)
        if active_widgets is not None:
            w = active_widgets[0]
            self.results_txt = w.find('results_txt')
            self.progress_bar = w.find('progress_bar')

        if self.autocomplete_maxerror_str != "":
            # Expected format: "<translation>, <rotation>"
            translation, rotation = [
                s.strip() for s in self.autocomplete_maxerror_str.split(",")
            ]
            self.max_error_translation = float(translation)
            self.max_error_rotation = float(rotation)

        # needs to match the SRG !!
        self.sync_source = 'calib_absolute_orientation'
        self.required_sinks = [
            'calib_absolute_orientation',
        ]

        # setup an errors buffer filled with NaN so completeness can be
        # detected via np.isnan(np.sum(...))
        self.errors_translation = np.array([np.nan] * self.result_count,
                                           dtype=np.double)
        self.errors_rotation = np.array([np.nan] * self.result_count,
                                        dtype=np.double)

        if self.facade is not None:
            self.facade.observe("is_loaded", self.connector_setup)

    def connector_setup(self, change):
        """Wire the connector to the facade once the dataflow is loaded.

        """
        if change['value'] and self.verify_connector():
            self.connector.setup(self.facade.instance)
            self.connector.observe(self.sync_source, self.handle_data)
            self.is_ready = True

    def handle_data(self, c):
        """Process a new absolute-orientation result.

        Updates the error ring buffers, the progress bar, and stops the
        calibration once all buffered errors are below the thresholds.
        """
        if self.connector.calib_absolute_orientation is not None:
            ao = self.connector.calib_absolute_orientation.get()
            self.results_txt.text = "Result:\n%s" % str(ao)

            if self.last_result is not None:
                t_error = norm(ao.translation() -
                               self.last_result.translation())
                self.errors_translation[0] = t_error
                # implement simple ringbuffer
                self.errors_translation = np.roll(self.errors_translation, 1)

                if self.initial_error_translation == -1:
                    self.initial_error_translation = t_error

                r_error = abs(
                    math.Quaternion(ao.rotation().inverted() *
                                    self.last_result.rotation()).angle())
                self.errors_rotation[0] = r_error
                # implement simple ringbuffer
                self.errors_rotation = np.roll(self.errors_rotation, 1)

                if self.initial_error_rotation == -1:
                    self.initial_error_rotation = r_error

            self.last_result = ao

            # Update the progress bar. The initial errors are only set once
            # a delta has been computed, so t_error/r_error are bound
            # whenever this branch is taken.
            # BUGFIX: the second condition previously re-tested
            # initial_error_translation instead of initial_error_rotation.
            if (self.initial_error_translation != -1
                    and self.initial_error_rotation != -1):
                t_span = (self.initial_error_translation -
                          self.max_error_translation)
                r_span = (self.initial_error_rotation -
                          self.max_error_rotation)
                # Guard against zero/negative spans which would raise a
                # ZeroDivisionError or yield a meaningless progress value.
                if t_span > 0 and r_span > 0:
                    t_p = t_error / t_span
                    r_p = r_error / r_span
                    pv = int(np.sqrt(1 - max(0, min(max(t_p, r_p), 1))) * 100)
                    if pv > self.progress_bar.value:
                        self.progress_bar.value = pv

            # check if the minimum of self.result_count results have been received
            if not np.isnan(np.sum(self.errors_translation)) and not np.isnan(
                    np.sum(self.errors_rotation)):
                if np.all(self.errors_translation < self.max_error_translation) and \
                        np.all(self.errors_rotation < self.max_error_rotation):
                    log.info(
                        "Absolute Orientation: Results are satisfactory for translation (<%s) min: %s max: %s and rotation (<%s) min: %s max %s"
                        %
                        (self.max_error_translation,
                         np.min(self.errors_translation),
                         np.max(self.errors_translation),
                         self.max_error_rotation, np.min(self.errors_rotation),
                         np.max(self.errors_rotation)))
                    self.result_ok = True
                    self.progress_bar.value = 100
                    if self.autocomplete_enable:
                        self.stopCalibration()
コード例 #24
0
class CounterHistogramAnalysis(AnalysisWithFigure):
    '''
    Takes in shot data, generates histograms, fits histograms,
    and then plots various attributes as a function of iteration along with histograms with fit overplotted.

    NOTE(review): this module is Python 2 code (print statements); a port
    would be required to run it under Python 3.
    '''

    # =====================Fit Functions================= #
    def intersection(self, A0, A1, m0, m1, s0, s1):
        '''Return the crossing point of two Gaussians with amplitudes
        A0/A1, means m0/m1 and widths s0/s1 (used as the cut threshold).'''
        return (m1 * s0**2 - m0 * s1**2 -
                np.sqrt(s0**2 * s1**2 *
                        (m0**2 - 2 * m0 * m1 + m1**2 + 2 * np.log(A0 / A1) *
                         (s1**2 - s0**2)))) / (s0**2 - s1**2)

    def area(self, A0, A1, m0, m1, s0, s1):
        '''Return the combined area of the two Gaussians over [0, inf).'''
        return np.sqrt(
            np.pi / 2) * (A0 * s0 + A0 * s0 * erf(m0 / np.sqrt(2) / s0) +
                          A1 * s1 + A1 * s1 * erf(m1 / np.sqrt(2) / s1))

        # Normed Overlap for arbitrary cut point
    def overlap(self, xc, A0, A1, m0, m1, s0, s1):
        '''Return the normalized overlap of the two Gaussians for cut xc.'''
        err0 = A0 * np.sqrt(np.pi / 2) * s0 * (1 - erf(
            (xc - m0) / np.sqrt(2) / s0))
        err1 = A1 * np.sqrt(np.pi / 2) * s1 * (erf(
            (xc - m1) / np.sqrt(2) / s1) + erf(m1 / np.sqrt(2) / s1))
        return (err0 + err1) / self.area(A0, A1, m0, m1, s0, s1)

        # Relative Fraction in 1
    def frac(self, A0, A1, m0, m1, s0, s1):
        '''Return the relative population fraction in the second Gaussian.'''
        return 1 / (1 + A0 * s0 * (1 + erf(m0 / np.sqrt(2) / s0)) / A1 / s1 /
                    (1 + erf(m1 / np.sqrt(2) / s1)))

    def dblgauss(self, x, A0, A1, m0, m1, s0, s1):
        '''Sum of two Gaussians; the model fitted to the histograms.'''
        return A0 * np.exp(-(x - m0)**2 /
                           (2 * s0**2)) + A1 * np.exp(-(x - m1)**2 /
                                                      (2 * s1**2))

    # ==================================================== #

    # Guard against re-entrant figure updates.
    update_lock = Bool(False)
    # Analysis only runs when enabled.
    enable = Bool(False)
    # Histogram bin count; negative means "one bin per count value".
    hbins = Int(30)
    hist1 = None
    hist2 = None

    def __init__(self, name, experiment, description=''):
        super(CounterHistogramAnalysis, self).__init__(name, experiment,
                                                       description)
        self.properties += ['enable']

    def preExperiment(self, experimentResults):
        # self.hist_rec = np.recarray(1,)
        return

    def analyzeMeasurement(self, measurementResults, iterationResults,
                           experimentResults):
        # Per-measurement analysis is not needed; work happens per iteration.
        return

    def analyzeIteration(self, iterationResults, experimentResults):
        '''Fit a double Gaussian to each shot's count histogram and store
        the fit results and histograms into iterationResults.'''
        if self.enable:
            histout = []  # amplitudes, edges
            # Overlap, fraction, cutoff
            fitout = np.recarray(2, [('overlap', float), ('fraction', float),
                                     ('cutoff', float)])
            optout = np.recarray(2,
                                 [('A0', float), ('A1', float), ('m0', float),
                                  ('m1', float), ('s0', float), ('s1', float)])
            shots = iterationResults['shotData'][()]
            # make shot number the primary axis
            shots = shots.reshape(-1, *shots.shape[2:]).swapaxes(0, 1)
            shots = shots[:, :, 0]  # pick out first roi only
            hbins = self.hbins
            if self.hbins < 0:
                hbins = np.arange(np.max(shots) + 1)
            for i in range(shots.shape[0]):
                # gmix is a module-level Gaussian mixture model used to seed
                # the curve fit — presumably sklearn's GMM; confirm upstream.
                gmix.fit(np.array([shots[i]]).transpose())
                h = np.histogram(shots[i], bins=hbins, normed=True)
                histout.append((h[1][:-1], h[0]))
                est = [
                    gmix.weights_.max() / 10,
                    gmix.weights_.min() / 10,
                    gmix.means_.min(),
                    gmix.means_.max(),
                    np.sqrt(gmix.means_.min()),
                    np.sqrt(gmix.means_.max())
                ]
                try:
                    popt, pcov = curve_fit(self.dblgauss, h[1][1:], h[0], est)
                    # popt=[A0,A1,m0,m1,s0,s1] : Absolute value
                    popt = np.abs(popt)
                    xc = self.intersection(*popt)
                    if np.isnan(xc):
                        print 'Bad Cut on Shot: {}'.format(i)
                        fitout[i] = np.nan, np.nan, np.nan
                        optout[i] = popt * np.nan
                    else:
                        fitout[i] = self.overlap(xc,
                                                 *popt), self.frac(*popt), xc
                        optout[i] = popt
                except (RuntimeError, RuntimeWarning, TypeError):
                    print 'Bad fit on Shot: {} '.format(i)
                    fitout[i] = np.nan, np.nan, np.nan
                    optout[i] = np.ones(6) * np.nan
            iterationResults['analysis/dblGaussPopt'] = optout
            iterationResults['analysis/dblGaussFit'] = fitout
            print histout
            iterationResults['analysis/histogram'] = np.array(histout,
                                                              dtype='uint32')
            self.updateFigure(iterationResults)
        return

    def updateFigure(self, iterationResults):
        '''Redraw the offscreen figure with per-shot histograms and the
        fitted double-Gaussian curves overplotted.'''
        if self.draw_fig:
            if self.enable:
                if not self.update_lock:
                    try:
                        self.update_lock = True

                        # There are two figures in an AnalysisWithFigure.  Draw to the offscreen figure.
                        fig = self.backFigure
                        # Clear figure.
                        fig.clf()
                        shots = iterationResults['shotData'][()]
                        # flatten sub-measurement dimension
                        # make shot number the primary axis (not measurement)
                        shots = shots.reshape(-1,
                                              *shots.shape[2:]).swapaxes(0, 1)
                        roi = 0
                        shots = shots[:, :, roi]  # pick out first roi only
                        popts = iterationResults['analysis/dblGaussPopt']
                        # fits = iterationResults['analysis/dblGaussFit']

                        # make one plot
                        for i in range(len(shots)):
                            ax = fig.add_subplot('{}1{}'.format(
                                len(shots), 1 + i))
                            hbins = self.hbins
                            if self.hbins < 0:
                                # use explicit bins
                                hbins = np.arange(np.max(shots[i, :]) + 1)
                            h = ax.hist(shots[i],
                                        bins=hbins,
                                        histtype='step',
                                        normed=True)
                            ax.plot(h[1][1:] - .5,
                                    self.dblgauss(h[1][1:], *popts[i]))
                            if i == 1:
                                # Log scale for the second shot so the tail
                                # populations remain visible.
                                ax.set_yscale('log', nonposy='clip')
                                ax.set_ylim(
                                    10**int(-np.log10(len(shots[i])) - 1), 1)
                            else:
                                ax.set_ylim(0, 1.05 * np.max(h[0]))

                        super(CounterHistogramAnalysis, self).updateFigure()

                    except:
                        logger.exception(
                            'Problem in CounterHistogramAnalysis.updateFigure().'
                        )
                    finally:
                        self.update_lock = False
コード例 #25
0
class ColormeshFormat(PlotFormat):
    """Plot format backed by a matplotlib ``pcolormesh`` (QuadMesh).

    Manages the mesh collection, an optional colorbar, the color limits
    (``vmin``/``vmax``), and a pair of cross-section cursor lines, and
    offers transformations of the 2D data into families of line plots.
    """

    #: the QuadMesh collection returned by ``axes.pcolormesh``
    clt = Typed(QuadMesh)
    #: 2D array of z values rendered as colors
    zdata = Array()
    expand_XY = Bool(False).tag(
        desc="expands X and Y array by one so all zdata is plotted")

    #: draw a colorbar next to the plot when True
    include_colorbar = Bool(False)
    #: colormap name, forwarded to matplotlib via the "cmap" kwarg
    cmap = Enum(*colormap_names).tag(former="cmap")

    @plot_observe("cmap")
    def colormap_update(self, change):
        """Apply a colormap change to the mesh and refresh the colorbar."""
        self.plot_set(change["name"])
        self.set_colorbar()

    @plot_observe("plotter.selected")
    def colorbar_update(self, change):
        """Refresh the colorbar when this plot becomes the selected one."""
        if self.plotter.selected == self.plot_name:
            self.set_colorbar()

    def get_colorbar(self):
        """Return the plotter's colorbar, creating it on first use."""
        if self.include_colorbar:
            if self.plotter.colorbar is None:
                self.plotter.colorbar = self.plotter.figure.colorbar(self.clt)
        return self.plotter.colorbar

    def set_colorbar(self):
        """Force the colorbar to redraw from the current mesh state."""
        if self.include_colorbar:
            self.get_colorbar().update_bruteforce(self.clt)

    def set_clim(self, vmin, vmax):
        """Set the color limits on both this format and the mesh."""
        self.vmin = float(vmin)
        self.vmax = float(vmax)
        self.clt.set_clim(vmin, vmax)
        self.set_colorbar()

    @plot_observe("vmin", "vmax")
    def clim_update(self, change):
        """Re-apply the color limits when either bound changes."""
        self.set_clim(self.vmin, self.vmax)

    #: color limits of the mesh
    vmin = Float()
    vmax = Float()

    def _default_plot_type(self):
        return "colormap"

    #: cross-section cursor lines (horizontal/vertical)
    h_line = Typed(Line2D)
    v_line = Typed(Line2D)

    #: styling of the cross-section cursor lines
    cs_alpha = Float(1.0)
    cs_color = Enum(*colors_tuple[1:])
    cs_linewidth = Float(2.0)
    cs_linestyle = Enum('solid', 'dashed', 'dashdot', 'dotted')

    def cs_set(self, param):
        """Forward a ``cs_*`` attribute to both cursor lines.

        ``param[3:]`` strips the ``cs_`` prefix to form the matplotlib
        setter name, e.g. ``cs_alpha`` -> ``set_alpha``.
        """
        getattr(self.h_line, "set_" + param[3:])(getattr(self, param))
        getattr(self.v_line, "set_" + param[3:])(getattr(self, param))

    @observe("cs_alpha", "cs_linewidth", "cs_linestyle", "cs_color")
    def cs_update(self, change):
        """Propagate cursor-line styling changes and redraw both figures."""
        if change["type"] == "update":
            self.cs_set(change["name"])
            # use identity comparison with None (was `!= None`)
            if self.plotter.horiz_fig.canvas is not None:
                self.plotter.horiz_fig.canvas.draw()
            if self.plotter.vert_fig.canvas is not None:
                self.plotter.vert_fig.canvas.draw()

    def do_autolim(self):
        """Autoscale the color limits from the data (NaN-safe) if enabled."""
        if self.plotter.auto_zlim:
            self.set_clim(nanmin(self.zdata), nanmax(self.zdata))
        else:
            self.set_clim(self.vmin, self.vmax)
        super(ColormeshFormat, self).do_autolim()

    def pcolormesh(self, *args, **kwargs):
        """Create the mesh from ``(z)`` or ``(x, y, z)`` arguments.

        With a single argument the x/y grids default to index ranges.
        With three arguments and ``expand_XY`` set, x and y are expanded
        by one sample so that every z value is drawn.
        """
        kwargs = process_kwargs(self, kwargs)
        self.remove_collection()
        if len(args) == 1:
            self.zdata = asanyarray(args[0])
            numRows, numCols = self.zdata.shape
            self.xdata = arange(numCols)
            self.ydata = arange(numRows)
        elif len(args) == 3:
            self.xdata, self.ydata, self.zdata = [asanyarray(a) for a in args]
            if self.expand_XY:
                self.xdata = linspace(min(self.xdata), max(self.xdata),
                                      len(self.xdata) + 1)
                self.ydata = linspace(min(self.ydata), max(self.ydata),
                                      len(self.ydata) + 1)

        self.clt = self.plotter.axes.pcolormesh(self.xdata, self.ydata,
                                                self.zdata, **kwargs)
        self.do_autolim()

    #: next column/row index used by append_xy when none is given
    count = Int()

    def append_xy(self, z, index=None, axis=1):
        """Write a 1D slice ``z`` into ``zdata`` and redraw the mesh.

        ``axis=1`` fills column ``index``; any other value fills row
        ``index``.  When ``index`` is omitted an internal counter is used
        and advanced.  The color limits are rescaled to the new data.
        """
        if index is None:
            index = self.count
            self.count += 1
        if axis == 1:
            self.zdata[:, index] = z
        else:
            self.zdata[index, :] = z
        self.clt.set_array(self.zdata.ravel())
        self.set_clim(amin(self.zdata), amax(self.zdata))
        self.update_plot(update_legend=False)

    @transformation
    def colormap2horizlines(self):
        """Convert the mesh into one line per row of ``zdata``."""
        mlf = MultiLineFormat(plot_name=self.plot_name, plotter=self.plotter)
        # list(...) so the point pairs are concrete sequences under
        # Python 3, where zip is a one-shot iterator
        mlf.multiline_plot([
            list(zip(self.xdata, self.zdata[n, :]))
            for n in range(self.ydata.shape[0])
        ])
        mlf.ydata = self.ydata
        mlf.autocolor_set("color")
        return mlf

    @transformation
    def colormap2vertlines(self):
        """Convert the mesh into one line per column of ``zdata``."""
        mlf = MultiLineFormat(plot_name=self.plot_name, plotter=self.plotter)
        mlf.multiline_plot([
            list(zip(self.ydata, self.zdata[:, n]))
            for n in range(self.xdata.shape[0])
        ])
        mlf.ydata = self.xdata
        mlf.autocolor_set("color")
        return mlf
コード例 #26
0
class CounterAnalysis(AnalysisWithFigure):
    """Analysis that reformats raw counter time-series data into
    (measurement, shot, roi) bins, stores per-measurement sums in the
    results file, and plots per-measurement/per-iteration diagnostics.
    """
    #: list of per-measurement 4D arrays (sub-meas, shots, rois, bins)
    counter_array = Member()
    #: accumulated per-iteration sums (cycles, sub-meas, shots, rois)
    binned_array = Member()
    meas_analysis_path = Str()
    meas_data_path = Str()
    iter_analysis_path = Str()
    #: re-entrancy guard for updateFigure
    update_lock = Bool(False)
    enable = Bool()
    #: number of dropped bins preceding each shot
    drops = Int(3)
    #: number of data bins per shot
    bins = Int(25)
    #: number of shots per measurement
    shots = Int(2)
    ROIs = List([0])
    #: roi index shown in the figure
    graph_roi = Int(0)

    def __init__(self, name, experiment, description=''):
        super(CounterAnalysis, self).__init__(name, experiment, description)
        self.meas_analysis_path = 'analysis/counter_data'
        self.meas_data_path = 'data/counter/data'
        self.iter_analysis_path = 'shotData'
        self.properties += ['enable', 'drops', 'bins', 'shots', 'graph_roi']

    def preIteration(self, iterationResults, experimentResults):
        """Reset the per-iteration accumulators."""
        self.counter_array = []
        self.binned_array = None

    def format_data(self, array):
        """Formats raw 2D counter data into the required 4D format.

        Formats raw 2D counter data with implicit stucture:
            [   # counter 0
                [ dropped_bins shot_time_series dropped_bins shot_time_series ... ],
                # counter 1
                [ dropped_bins shot_time_series dropped_bins shot_time_series ... ]
            ]
        into the 4D format expected by the subsequent analyses:
        [   # measurements, can have different lengths run-to-run
            [   # shots array, fixed size
                [   # roi list, shot 0
                    [ time_series_roi_0 ],
                    [ time_series_roi_1 ],
                    ...
                ],
                [   # roi list, shot 1
                    [ time_series_roi_0 ],
                    [ time_series_roi_1 ],
                    ...
                ],
                ...
            ],
            ...
        ]
        """
        rois, bins = array.shape[:2]
        bins_per_shot = self.drops + self.bins  # self.bins is data bins per shot
        # calculate the number of shots dynamically
        num_shots = int(bins / (bins_per_shot))
        # calculate the number of measurements contained in the raw data
        # there may be extra shots if we get branching implemented
        num_meas = num_shots // self.shots
        # build a mask for removing valid data
        shot_mask = ([False] * self.drops + [True] * self.bins)
        good_shots = self.shots * num_meas
        # mask for the roi
        ctr_mask = np.array(shot_mask * good_shots + 0 * shot_mask *
                            (num_shots - good_shots),
                            dtype='bool')
        # apply mask and reshape partially
        array = array[:, ctr_mask].reshape(
            (rois, num_meas, self.shots, self.bins))
        array = array.swapaxes(0, 1)  # swap rois and measurement axes
        array = array.swapaxes(1, 2)  # swap rois and shots axes
        return array

    def analyzeMeasurement(self, measurementResults, iterationResults,
                           experimentResults):
        """Format this measurement's raw counter data and record the
        per-shot sums, then refresh the figure.
        """
        if self.enable:
            # MFE 2018/01: this analysis has been generalized such that
            # multiple sub measurements can occur in the same traditional
            # measurement
            array = measurementResults[self.meas_data_path][()]
            try:
                # package data into an array with shape
                # (sub measurements, shots, counters, time series data)
                array = self.format_data(array)
                # flatten the sub_measurements by converting top level to
                # normal list and concatenating
                self.counter_array += list(array)
            except ValueError:
                errmsg = "Error retrieving counter data.  Offending counter data shape: {}"
                logger.exception(errmsg.format(array.shape))
            except Exception:
                logger.exception('Unhandled counter data exception')
            else:
                # only record sums when formatting succeeded; previously this
                # ran unconditionally and crashed on the 2D array after a
                # caught formatting error
                # Note the constant 1 is for the roi column parameter, all
                # counters get entered in a single row
                n_meas, n_shots, n_rois, bins = array.shape
                sum_array = array.sum(axis=3).reshape(
                    (n_meas, n_shots, n_rois, 1))
                measurementResults[self.meas_analysis_path] = sum_array
                # put the sum data in the expected format for display
                if self.binned_array is None:
                    self.binned_array = [
                        sum_array.reshape((n_meas, n_shots, n_rois))
                    ]
                else:
                    self.binned_array = np.concatenate(
                        (self.binned_array,
                         [sum_array.reshape((n_meas, n_shots, n_rois))]))
        self.updateFigure()

    def analyzeIteration(self, iterationResults, experimentResults):
        """Recollect the per-measurement sums from the results file (to
        drop cut data) and store them at the iteration level.
        """
        if self.enable:
            # sorted() works on the Py3 map iterator too; the previous
            # `map(...).sort()` is a Python 3 AttributeError
            meas = sorted(int(m)
                          for m in iterationResults['measurements'].keys())
            path = 'measurements/{}/' + self.meas_analysis_path
            try:
                res = np.array(
                    [iterationResults[path.format(m)] for m in meas])
            except KeyError:
                # I was having problem with the file maybe not being ready
                logger.warning(
                    "Issue reading hdf5 file. Waiting then repeating.")
                time.sleep(0.1)  # try again in a little
                res = []
                for m in meas:
                    try:
                        res.append(iterationResults[path.format(m)])
                    except KeyError:
                        msg = ("Reading from hdf5 file during measurement `{}`"
                               " failed.").format(m)
                        logger.exception(msg)
                res = np.array(res)
            total_meas = len(self.binned_array)
            # drop superfluous ROI_columns dimension
            self.binned_array = res.reshape(res.shape[:4])
            print('cut data: {}'.format(total_meas - len(self.binned_array)))
            iterationResults[self.iter_analysis_path] = self.binned_array
        return

    def updateFigure(self):
        """Redraw the diagnostic figure (guarded against re-entrancy)."""
        if self.draw_fig:
            if self.enable:
                if not self.update_lock:
                    try:
                        self.update_lock = True

                        # There are two figures in an AnalysisWithFigure.
                        # Draw to the offscreen figure.
                        fig = self.backFigure
                        # Clear figure.
                        fig.clf()

                        # last measurement (221) and iteration average (222)
                        ax = fig.add_subplot(221)
                        ax2 = fig.add_subplot(222)
                        ptr = 0
                        ca = np.array(self.counter_array)
                        for s in range(self.shots):
                            xs = np.arange(ptr, ptr + self.bins)
                            ax.bar(xs, ca[-1, s, self.graph_roi])
                            ax2.bar(xs, ca[:, s, self.graph_roi].mean(0))
                            # leave a visible gap between shots
                            ptr += max(1.05 * self.bins, self.bins + 1)
                        ax.set_title('Measurement: {}'.format(len(ca)))
                        ax2.set_title('Iteration average')

                        # time series (223) and histogram (224) of sum data
                        ax = fig.add_subplot(223)
                        ax2 = fig.add_subplot(224)
                        n_shots = self.binned_array.shape[2]
                        legends = []
                        for roi in range(self.binned_array.shape[3]):
                            for s in range(n_shots):
                                ax.plot(
                                    self.binned_array[:, :, s, roi].flatten(),
                                    '.')
                                # bins = max + 2 takes care of the case where
                                # all entries are 0, which causes an error in
                                # the plot
                                ax2.hist(
                                    self.binned_array[:, :, s, roi].flatten(),
                                    bins=np.arange(
                                        np.max(self.binned_array[:, :, s, roi].
                                               flatten()) + 2),
                                    histtype='step')
                                legends.append("c{}_s{}".format(roi, s))
                        ax.set_title('Binned Data')
                        ax2.legend(legends, fontsize='small', loc=0)
                        super(CounterAnalysis, self).updateFigure()

                    except Exception:
                        logger.exception(
                            'Problem in CounterAnalysis.updateFigure()')
                    finally:
                        self.update_lock = False
コード例 #27
0
ファイル: models.py プロジェクト: yairvillarp/inkcut
class Job(Model):
    """ Create a plot depending on the properties set. Any property that is a
    traitlet will cause an update when the value is changed.

    """
    #: Material this job will be run on
    material = Instance(Material, ()).tag(config=True)

    #: Path to svg document this job parses
    document = Unicode().tag(config=True)

    #: Nodes to restrict
    document_kwargs = Dict().tag(config=True)

    #: Meta info about the job
    info = Instance(JobInfo, ()).tag(config=True)

    # Job properties used for generating the plot
    size = ContainerList(Float(), default=[1, 1])
    scale = ContainerList(Float(), default=[1, 1]).tag(config=True)
    auto_scale = Bool(False).tag(
        config=True, help="automatically scale if it's too big for the area")
    # NOTE(review): lock_scale's help text duplicates auto_scale's — looks
    # like a copy-paste leftover; confirm the intended description.
    lock_scale = Bool(True).tag(
        config=True, help="automatically scale if it's too big for the area")

    mirror = ContainerList(Bool(), default=[False, False]).tag(config=True)
    align_center = ContainerList(Bool(), default=[False,
                                                  False]).tag(config=True)

    rotation = Float(0).tag(config=True)
    auto_rotate = Bool(False).tag(
        config=True, help="automatically rotate if it saves space")

    copies = Int(1).tag(config=True)
    auto_copies = Bool(False).tag(config=True, help="always use a full stack")
    copy_spacing = ContainerList(Float(), default=[10, 10]).tag(config=True)
    copy_weedline = Bool(False).tag(config=True)
    copy_weedline_padding = ContainerList(Float(),
                                          default=[10, 10, 10,
                                                   10]).tag(config=True)

    plot_weedline = Bool(False).tag(config=True)
    plot_weedline_padding = ContainerList(Float(),
                                          default=[10, 10, 10,
                                                   10]).tag(config=True)

    #: Ordering strategy, one of the registered ordering handlers
    order = Enum(*sorted(ordering.REGISTRY.keys())).tag(config=True)

    def _default_order(self):
        return 'Normal'

    feed_to_end = Bool(False).tag(config=True)
    feed_after = Float(0).tag(config=True)

    #: How many copies fit across/down the usable material area
    stack_size = ContainerList(Int(), default=[0, 0])

    path = Instance(QtGui.QPainterPath)  # Original path
    model = Instance(QtGui.QPainterPath)  # Copy using job properties

    _blocked = Bool(False)  # block change events
    _desired_copies = Int(1)  # required for auto copies

    def __setstate__(self, *args, **kwargs):
        """ Ensure that when restoring from disk the material and info
        are not set to None. Ideally these would be defined as Typed but
        the material may be made extendable at some point.
        """
        super(Job, self).__setstate__(*args, **kwargs)
        if not self.info:
            self.info = JobInfo()
        if not self.material:
            self.material = Material()

    def _observe_document(self, change):
        """ Read the document from stdin or from the given file path """
        if change['type'] == 'update' and self.document == '-':
            #: Only load from stdin when explicitly changed to it (when doing
            #: open from the cli) otherwise when restoring state this hangs
            #: startup
            self.path = QtSvgDoc(sys.stdin, **self.document_kwargs)
        elif self.document and os.path.exists(self.document):
            self.path = QtSvgDoc(self.document, **self.document_kwargs)

    def _create_copy(self):
        """ Creates a copy of the original graphic applying the given
        transforms: scale/mirror, rotation, the per-copy weedline, the
        ordering handler, and (if enabled) auto-scaling to fit.

        """
        bbox = self.path.boundingRect()

        # Create the base copy; mirroring is a negative scale on that axis
        t = QtGui.QTransform()

        t.scale(
            self.scale[0] * (self.mirror[0] and -1 or 1),
            self.scale[1] * (self.mirror[1] and -1 or 1),
        )

        # Rotate about center
        if self.rotation != 0:
            c = bbox.center()
            t.translate(-c.x(), -c.y())
            t.rotate(self.rotation)
            t.translate(c.x(), c.y())

        # Apply transform
        path = self.path * t

        # Add weedline to copy
        if self.copy_weedline:
            self._add_weedline(path, self.copy_weedline_padding)

        # Apply ordering to path
        # this delegates to objects in the ordering module
        # TODO: Should this be done via plugins?
        OrderingHandler = ordering.REGISTRY.get(self.order)
        if OrderingHandler:
            path = OrderingHandler().order(self, path)

        # If it's too big we have to scale it
        w, h = path.boundingRect().width(), path.boundingRect().height()
        available_area = self.material.available_area

        # NOTE(review): original comment flagged this autoscale branch as
        # problematic ("This screws stuff up!")
        if w > available_area.width() or h > available_area.height():

            # If it's too big and auto scale is enabled, resize it to fit
            if not self.auto_scale:
                raise JobError("Image is too large to fit on the material")
            sx, sy = 1, 1
            if w > available_area.width():
                sx = available_area.width() / w
            if h > available_area.height():
                sy = available_area.height() / h
            s = min(sx, sy)  # Fit to the smaller of the two
            path = self.path * QtGui.QTransform.fromScale(s, s)

        # Move to bottom left
        p = path.boundingRect().bottomRight()

        path = path * QtGui.QTransform.fromTranslate(-p.x(), -p.y())

        return path

    @contextmanager
    def events_suppressed(self):
        """ Block change events to prevent feedback loops

        """
        self._blocked = True
        try:
            yield
        finally:
            self._blocked = False

    @observe('path', 'scale', 'auto_scale', 'lock_scale', 'mirror',
             'align_center', 'rotation', 'auto_rotate', 'copies', 'order',
             'copy_spacing', 'copy_weedline', 'copy_weedline_padding',
             'plot_weedline', 'plot_weedline_padding', 'feed_to_end',
             'feed_after', 'material', 'material.size', 'material.padding',
             'auto_copies')
    def _job_changed(self, change):
        """ Recreate an instance of the plot using the current settings

        """
        if self._blocked:
            return

        if change['name'] == 'copies':
            self._desired_copies = self.copies

        model = QtGui.QPainterPath()

        if not self.path:
            return

        path = self._create_copy()

        # Update size
        bbox = path.boundingRect()
        self.size = [bbox.width(), bbox.height()]

        # Create copies
        c = 0
        points = self._copy_positions_iter(path)

        if self.auto_copies:
            self.stack_size = self._compute_stack_sizes(path)
            if self.stack_size[0]:
                copies_left = self.copies % self.stack_size[0]
                if copies_left:  # not a full stack
                    with self.events_suppressed():
                        self.copies = self._desired_copies
                        self.add_stack()

        while c < self.copies:
            x, y = next(points)
            model.addPath(path * QtGui.QTransform.fromTranslate(x, -y))
            c += 1

        # Create weedline
        if self.plot_weedline:
            self._add_weedline(model, self.plot_weedline_padding)

        # Move to 0,0
        bbox = model.boundingRect()
        p = bbox.bottomLeft()
        tx, ty = -p.x(), -p.y()

        # Center or set to padding
        tx += ((self.material.width() - bbox.width()) /
               2.0 if self.align_center[0] else self.material.padding_left)
        ty += (-(self.material.height() - bbox.height()) /
               2.0 if self.align_center[1] else -self.material.padding_bottom)

        t = QtGui.QTransform.fromTranslate(tx, ty)

        model = model * t

        # Optionally feed past the end of the plot when done
        end_point = (QtCore.QPointF(
            0, -self.feed_after + model.boundingRect().top())
                     if self.feed_to_end else QtCore.QPointF(0, 0))
        model.moveTo(end_point)

        # Set new model
        self.model = model

    def _check_bounds(self, plot, area):
        """ Return True when the plot is larger than the area in either
        dimension, i.e. it does NOT fit.

        """
        return plot.width() > area.width() or plot.height() > area.height()

    def _copy_positions_iter(self, path, axis=0):
        """ Generator yielding the [x, y] position of each copy, filling
        rows along ``axis`` and stepping the other axis between rows.

        NOTE: the same list object is yielded every time; callers must
        unpack it immediately (as ``_job_changed`` does).
        """
        # BUGFIX: was `axis + 1 % 2`, which parses as `axis + (1 % 2)` and
        # indexes out of range when axis == 1
        other_axis = (axis + 1) % 2
        p = [0, 0]

        bbox = path.boundingRect()
        d = (bbox.width(), bbox.height())
        pad = self.copy_spacing
        stack_size = self._compute_stack_sizes(path)

        while True:
            p[axis] = 0
            yield p  # Beginning of each row

            for i in range(stack_size[axis] - 1):
                p[axis] += d[axis] + pad[axis]
                yield p

            p[other_axis] += d[other_axis] + pad[other_axis]

    def _compute_stack_sizes(self, path):
        """ Count how many copies of ``path`` fit across and down the
        material's usable (padding-reduced) area.

        """
        # Usable area
        material = self.material
        a = [material.width(), material.height()]
        a[0] -= material.padding[Padding.LEFT] + material.padding[
            Padding.RIGHT]
        a[1] -= material.padding[Padding.TOP] + material.padding[
            Padding.BOTTOM]

        # Clone includes weedline but not spacing
        bbox = path.boundingRect()
        size = [bbox.width(), bbox.height()]

        stack_size = [0, 0]
        p = [0, 0]
        for i in range(2):
            # Compute stack
            while (p[i] + size[i]) < a[i]:  # while another one fits
                stack_size[i] += 1
                p[i] += size[i] + self.copy_spacing[i]  # Add only to end

        self.stack_size = stack_size
        return stack_size

    def _add_weedline(self, path, padding):
        """ Adds a weedline to the path
        by creating a box around the path with the given padding

        """
        bbox = path.boundingRect()
        w, h = bbox.width(), bbox.height()

        tl = bbox.topLeft()
        x = tl.x() - padding[Padding.LEFT]
        y = tl.y() - padding[Padding.TOP]

        w += padding[Padding.LEFT] + padding[Padding.RIGHT]
        h += padding[Padding.TOP] + padding[Padding.BOTTOM]

        path.addRect(x, y, w, h)
        return path

    @property
    def state(self):
        # placeholder; not implemented
        pass

    @property
    def move_path(self):
        """ Returns the path the head moves when not cutting

        """
        # Compute the negative: moveTo segments become lineTo and vice versa
        path = QtGui.QPainterPath()
        for i in range(self.model.elementCount()):
            e = self.model.elementAt(i)
            if e.isMoveTo():
                path.lineTo(e.x, e.y)
            else:
                path.moveTo(e.x, e.y)
        return path

    @property
    def cut_path(self):
        """ Returns path where it is cutting

        """
        return self.model

    def add_stack(self):
        """ Add a complete stack or fill the row

        """
        copies_left = self.stack_size[0] - (self.copies % self.stack_size[0])
        if copies_left == 0:  # Add full stack
            self.copies = self.copies + self.stack_size[0]
        else:  # Fill stack
            self.copies = self.copies + copies_left

    def remove_stack(self):
        """ Remove a complete stack or the rest of the row

        """
        if self.copies <= self.stack_size[0]:
            self.copies = 1
            return

        copies_left = self.copies % self.stack_size[0]
        if copies_left == 0:  # Remove full stack
            self.copies = self.copies - self.stack_size[0]
        else:  # Remove the partial row
            self.copies = self.copies - copies_left

    def clone(self):
        """ Return a cloned instance of this object

        """
        state = self.__getstate__()
        state.update({
            'material': Material(**self.material.__getstate__()),
            'info': JobInfo(**self.info.__getstate__()),
        })
        return Job(**state)
コード例 #28
0
class Lyzer(TA88_Fund):
    """Analysis wrapper for a TA88 wide-frequency low-power S21 sweep.

    Reads complex S21 traces and a yoko voltage sweep from an hdf5 file
    and exposes derived magnitude/phase quantities as plottable
    properties.

    NOTE(review): this class uses Python 2 print statements and will not
    run unchanged on Python 3.
    """
    rd_hdf=Typed(TA88_Read)  # hdf5 reader pointing at the data file

    comment=Unicode().tag(read_only=True, spec="multiline")  # file's comment attribute

    rt_atten=Float(60)  # room-temperature attenuation — presumably dB; verify

    rt_gain=Float(26*2)  # room-temperature gain — presumably dB; verify

    frequency=Array().tag(unit="GHz", plot=True, label="Frequency")
    yoko=Array().tag(unit="V", plot=True, label="Yoko")
    Magcom=Array().tag(private=True)  # complex S21 data, shape (frequency, yoko)

    probe_frq=Float().tag(unit="GHz", label="Probe frequency", read_only=True)
    probe_pwr=Float().tag(label="Probe power", read_only=True, display_unit="dBm/mW")

    pind=Int()


    @tag_Property(display_unit="dB", plot=True)
    def MagdB(self):
        """Magnitude scaled by the module-level dB constant."""
        return self.Magcom[:, :]/dB#-mean(self.Magcom[:, 169:171], axis=1, keepdims=True)/dB

    @tag_Property(plot=True)
    def Phase(self):
        """Phase after subtracting the mean of columns 990:1010 as a baseline."""
        return angle(self.Magcom[:, :]-mean(self.Magcom[:, 990:1010], axis=1, keepdims=True))

    @tag_Property( plot=True)
    def MagAbs(self):
        """Magnitude after subtracting the mean of columns 620:700 as a baseline."""
        #return absolute(self.Magcom[:, :])
        return absolute(self.Magcom[:, :]-mean(self.Magcom[:, 620:700], axis=1, keepdims=True))


    def _default_rd_hdf(self):
        # default data file for this analysis
        return TA88_Read(main_file="Data_0309/S1A4_TA88_wide_f_lowpwr_overnight.hdf5")

    def read_data(self):
        """Load S21 traces, the yoko sweep and VNA metadata from the main
        hdf5 file into this object, then read a reference trace from a
        second hard-coded file and return its (frequency, Magcom) pair.
        """
        with File(self.rd_hdf.file_path, 'r') as f:
            print f["Traces"].keys()
            self.comment=f.attrs["comment"]
            print f["Instrument config"].keys()
            self.probe_frq=f["Instrument config"]['Rohde&Schwarz Network Analyzer - IP: 169.254.107.192, RS VNA at localhost'].attrs["Start frequency"]
            self.probe_pwr=f["Instrument config"]['Rohde&Schwarz Network Analyzer - IP: 169.254.107.192, RS VNA at localhost'].attrs["Output power"]
            print f["Instrument config"]['Rohde&Schwarz Network Analyzer - IP: 169.254.107.192, RS VNA at localhost'].attrs
#
            print f["Data"]["Channel names"][:]
            Magvec=f["Traces"]["RS VNA - S21"]#[:]
            data=f["Data"]["Data"]
            print shape(data)
#
            # first sweep column holds the yoko voltages
            self.yoko=data[:,0,0].astype(float64)
            # frequency axis is stored as (start, step)
            fstart=f["Traces"]['RS VNA - S21_t0dt'][0][0]
            fstep=f["Traces"]['RS VNA - S21_t0dt'][0][1]
            print shape(Magvec)
            sm=shape(Magvec)[0]
            sy=shape(data)
            s=(sm, sy[0], 1)#sy[2])
            print s
            # combine real/imaginary trace columns into complex S21
            Magcom=Magvec[:,0, :]+1j*Magvec[:,1, :]

            Magcom=reshape(Magcom, s, order="F")
            self.frequency=linspace(fstart, fstart+fstep*(sm-1), sm)
            print shape(Magcom)
            self.Magcom=squeeze(Magcom)
        # reference trace from a second (hard-coded) data file
        with File("/Users/thomasaref/Dropbox/Current stuff/Logbook/TA210715A45_cooldown270216/Data_0227/S4A4_TA88_wideSC1116unswitched.hdf5", "r") as f:
            Magvec=f["Traces"]["RS VNA - S21"]
            fstart=f["Traces"]['RS VNA - S21_t0dt'][0][0]
            fstep=f["Traces"]['RS VNA - S21_t0dt'][0][1]
            sm=shape(Magvec)[0]
            s=(sm, 1, 1)
            Magcom=Magvec[:,0,:]+1j*Magvec[:,1,:]
            Magcom=reshape(Magcom, s, order="F")
            frequency=linspace(fstart, fstart+fstep*(sm-1), sm)
            Magcom=squeeze(Magcom)
        return frequency, Magcom
コード例 #29
0
class Plotter(Atom):
     """Matplotlib-backed plot model (Atom observable), Python 2 code.

     Holds a Figure/Axes pair plus named PolyCollections (``clts``); Atom
     observers push title/label/scale changes to the axes.  Several methods
     still reference retired chaco plumbing (``self.pd``, ``self.plot``,
     ``self.clt``) whose member declarations are commented out -- those
     methods raise at runtime; see NOTE(review) comments below.
     """
     # Display name and axis labelling; observers sync these to the axes.
     name=Unicode()
     title=Unicode("yoyoyoyoyo")
     xlabel=Unicode()
     ylabel=Unicode()

     # name -> XYFormat/AllXYFormat formatting objects.
     xyfs=Dict()
     #pd=Typed(ArrayPlotData, ())
     # Chaco Plot forward reference -- only used by the chaco code paths.
     plot= ForwardTyped(lambda: Plot)
     color_index=Int()
     # name -> PolyCollection currently attached to the axes.
     clts=Dict()
     fig=Typed(Figure)
     axe=Typed(Axes)
     #clt=Typed(PolyCollection)
     plottables=Dict()

     overall_plot_type=Enum("XY plot", "img plot")
     value_scale=Enum('linear', 'log')
     index_scale=Enum('linear', 'log')

     #def _default_clt(self):
     #    return PolyCollection([((0,0), (0,0))], alpha=0.6, antialiased=True)#, rasterized=False, antialiased=False)

     # Lazily create the single subplot with autoscaling enabled.
     def _default_axe(self):
         axe=self.fig.add_subplot(111)
         axe.autoscale_view(True)
         return axe

     def _default_fig(self):
         return Figure()

     # Chaco-only: mirrors value_scale onto self.plot (request_redraw is a
     # chaco API, not matplotlib).
     def _observe_value_scale(self, change):
         if self.overall_plot_type=="XY plot":
             self.plot.value_scale=self.value_scale
             self.plot.request_redraw()

     # Chaco-only: mirrors index_scale onto self.plot.
     def _observe_index_scale(self, change):
         if self.overall_plot_type=="XY plot":
             self.plot.index_scale=self.index_scale
             self.plot.request_redraw()

     def _default_plottables(self):
         return dict(plotted=[None])

     # Observers below push label state to the matplotlib axes; the
     # commented lines are the old chaco equivalents.
     def _observe_title(self, change):
         self.axe.set_title(self.title)
         #self.plot.request_redraw()

     def _observe_xlabel(self, change):
         self.axe.set_xlabel(self.xlabel)
         #self.plot.x_axis.title=self.xlabel
         #self.plot.request_redraw()

     def _observe_ylabel(self, change):
         self.axe.set_ylabel(self.ylabel)
         #self.plot.y_axis.title=self.ylabel
         #self.plot.request_redraw()

     def _default_xyfs(self):
         xyf=AllXYFormat(plotter=self)
         return {"All":xyf}

     # Chaco-only: removes every named plot and resets the colour cycle.
     def delete_all_plots(self):
         for key in self.plot.plots.keys():
                self.plot.delplot(key)
         self.color_index=0

     # Chaco-only: renders self.plot to image_test.png.  PlotGraphicsContext
     # is imported lazily and cached in a module-level global.
     def _save(self):
         global PlotGraphicsContext
         if PlotGraphicsContext==None:
             from chaco.plot_graphics_context import PlotGraphicsContext
         win_size = self.plot.outer_bounds
         plot_gc = PlotGraphicsContext(win_size)#, dpi=300)
         plot_gc.render_component(self.plot)
         plot_gc.save("image_test.png")

     # Create or update the named PolyCollection from polygon vertices.
     # NOTE(review): `zdata!=None` is ambiguous for numpy arrays (elementwise
     # comparison); `is not None` is the safe test -- confirm zdata type.
     def set_data(self, zname, zdata, zcolor):
         if zdata!=None:
            if zname not in self.clts: #plottables['plotted']:#self.pd.list_data():
                clt=PolyCollection(zdata, alpha=0.5, antialiased=True)#, rasterized=False, antialiased=False)
                clt.set_color(colorConverter.to_rgba(zcolor))                
                self.clts[zname]=clt
                self.axe.add_collection(self.clts[zname], autolim=True)
            else:                
                self.clts[zname].set_verts(zdata)

     def add_text(self, text, x, y, **kwargs):
         """Add text at data location (x, y); kwargs pass to Axes.text."""
         self.axe.text(x, y, text, **kwargs)

     # Redraw the canvas if the figure is attached to one.
     def draw(self):
         if self.fig.canvas!=None:
             #trans = transforms.Affine2D().scale(self.fig.dpi/72.0)
             #self.clt.set_transform(trans)  # the points to pixels transform
             #self.clt.set_color(colors)
         
             #self.axe.autoscale_view(True)
             self.fig.canvas.draw()

     def set_xlim(self, xmin, xmax):
         self.axe.set_xlim(xmin, xmax)

     def set_ylim(self, ymin, ymax):
         self.axe.set_ylim(ymin, ymax)

     # NOTE(review): self.clt is never defined (its member declaration is
     # commented out above) -- this raises AttributeError as written.
     def get_data(self, zname, index=None, axis=0):
        data=[c.to_polygons() for c in self.clt.get_paths()]
        if index==None:
            return data
        if axis==0:
            return atleast_2d(data)[index, :]
        return atleast_2d(data)[:, index]

     # Chaco-only legacy path: relies on self.pd, whose declaration is
     # commented out -- AttributeError as written.
     def add_poly_plot_old(self, n, verts, cn="green", polyname=""):
         nxarray, nyarray = transpose(verts)
         xname=polyname+"x" + str(n)
         yname=polyname+"y" + str(n)
         self.pd.set_data(xname, nxarray, coord='x') #coord='x' is likely redundant or a metadata tag
         self.pd.set_data(yname, nyarray, coord='y')
         self.plot.plot((xname, yname),
                          type="polygon",
                          face_color=cn, #colors[nsides],
                          hittest_type="poly")[0]

     # Chaco-only: same self.pd caveat as add_poly_plot_old.
     def add_poly_plot(self, n, verts, cn="green", polyname=""):
        #for n,p in enumerate(self.polylist):
            log_debug("drawing polygon #: {0}".format(n))
            #npoints = p.verts #n_gon(center=p, r=2, nsides=nsides)
            nxarray, nyarray = transpose(verts)
            self.pd.set_data("x" + str(n), nxarray)
            self.pd.set_data("y" + str(n), nyarray)
            log_debug("data set")            
            self.plot.plot(("x"+str(n), "y"+str(n)),
                          type="polygon",
                          face_color=cn, #colors[nsides],
                          hittest_type="poly"
                          )[0]
            log_debug("plot occured")

     # Chaco-only image plot path (self.pd caveat applies here too).
     def add_img_plot(self, zname, zdata, xname=None, xdata=None, yname=None,  ydata=None):
         self.add_data(zname=zname, zdata=zdata, xname=xname, xdata=xdata, yname=yname, ydata=ydata, overwrite=True, concat=False)
         print self.pd.get_data(zname)
         xyf=XYFormat(plotter=self)
         xyf.draw_img_plot(name='img_plot', zname=zname, xname=xname, yname=yname)
         self.xyfs.update(**{xyf.name: xyf})
         self.overall_plot_type="img plot"

     # NOTE(review): calls self.set_data with 2 args but set_data takes
     # (zname, zdata, zcolor) -- TypeError as written; confirm intent.
     def add_line_plot(self, name, zname, zdata, xname='', xdata=None):
        #self.add_data(zname=zname, zdata=zdata, xname=xname, xdata=xdata, overwrite=True)
        self.set_data(zname, zdata)
        self.set_data(xname, xdata)
        xyf=XYFormat(plotter=self)
        zdata=self.get_data(zname)
        #if 1: #zdata.ndim>1:
        #    for i, arr in enumerate(self.splitMultiD(zdata, 0)):
        #        self.add_line_plot(name+str(i), zname+str(i), squeeze(arr), xname, xdata )
        #else:
            #self.set_data(zname, zdata)
            #if xname!=None and xdata!=None:
            #    self.set_data(xname, xdata, coord='x')
        xyf.draw_plot(name=name, zname=zname, xname=xname)
        self.xyfs.update(**{xyf.name: xyf})
        self.overall_plot_type="XY plot"

#     def append_data(self, name, zpoint, xpoint=None):
#         xyf=self.xyfs[name]
#         zdata=self.pd.get_data(xyf.zname)
#         zdata=append(zdata, zpoint)
#         self.pd.set_data(xyf.zname, zdata)
#         xdata=self.pd.get_data(xyf.xname)
#         if xpoint==None:
#             xpoint=max(xdata)+range(len(zpoint))+1
#         xdata=append(xdata, xpoint)
#         self.pd.set_data(xyf.xname, xdata)

#     def _default_plot(self):
#        global Plot, PanTool, ZoomTool, LegendTool
#        if Plot==None:
#            from chaco.plot import Plot
#        if PanTool==None or ZoomTool==None or LegendTool==None:
#            from chaco.tools.api import PanTool, ZoomTool,  LegendTool #, LineInspector
#
#        plot=Plot(self.pd, padding=50, fill_padding=True,
#                        bgcolor="white", use_backbuffer=True,  unified_draw=True)#, use_downsampling=True)
#        plot.tools.append(PanTool(plot, constrain_key="shift"))
#        plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))
#        plot.legend.tools.append(LegendTool(plot.legend, drag_button="right"))
#        return plot

     # Split a >=2-D array into per-row (or per-axis) pieces; 1-D input is
     # promoted to 2-D and returned whole.
     def splitMultiD(self, arr, axis=0):
        if arr.ndim<2:
            return atleast_2d(arr)
        else:
            return split(arr, arr.shape[axis], axis=axis)

     # Merge new array(s) with any existing data under `name`.
     # NOTE(review): if concat is False, `data` is never bound ->
     # UnboundLocalError; also set_data is called with 2 args (see above).
     def gatherMultiD(self, name, arrs, appen=None, concat=True, overwrite=False):
         if not isinstance(arrs, tuple):
             arrs=(arrs,)
         if appen==None:
             if shape(arrs)==(1,):
                 appen=True
             else:
                 appen=False
         orig=self.get_data(name)
         if orig!=None and not overwrite:
             arrs=(orig,)+arrs
         if appen:
             axis=1
         else:
             axis=0
         print arrs[0]==atleast_2d(*arrs)
         #if ndim(arrs[0])>1:
         #    concat=False

         if concat:
             data=concatenate(atleast_2d(*arrs), axis)
         self.set_data(name, data)

     # Gather optional x/y arrays before the mandatory z array.
     def add_data(self, zname, zdata, xname=None, xdata=None, yname=None, ydata=None, appen=None, concat=True, overwrite=False):
         if xname!=None:
             self.gatherMultiD(xname, xdata, appen=appen, overwrite=overwrite, concat=concat)
         if yname!=None:
             self.gatherMultiD(yname, ydata, appen=appen, overwrite=overwrite, concat=concat)
         self.gatherMultiD(zname, zdata, appen=appen, overwrite=overwrite, concat=concat)

     # Launch the enaml GUI around this plotter (blocks in app.start()).
     def show(self):
        with imports():
            from e_Plotter import PlotMain
        app = QtApplication()
        view = PlotMain(plotr=self)
        view.show()
        app.start()
コード例 #30
0
class GuessParamModel(Atom):
    """
    This is auto fit model to guess the initial parameters.

    Attributes
    ----------
    parameters : `atom.Dict`
        A list of `Parameter` objects, subclassed from the `Atom` base class.
        These `Parameter` objects hold all relevant xrf information.
    data : array
        1D array of spectrum
    prefit_x : array
        X axis with range defined by low and high limits.
    result_dict : dict
        Save all the auto fitting results for each element.
        It is a dictionary of object PreFitStatus.
    param_d : dict
        Parameters can be transferred into this dictionary.
    param_new : dict
        More information are saved, such as element position and width.
    total_y : dict
        Results from K lines
    total_y_l : dict
        Results from L lines
    total_y_m : dict
        Results from M lines
    e_list : str
        All elements used for fitting.
    file_path : str
        The path where file is saved.
    element_list : list

    Note: some attributes described above (e.g. ``param_d``, ``total_y_l``)
    are not declared on this class -- TODO reconcile docstring with members.
    """
    # Pristine parameter dict kept for "restore defaults".
    default_parameters = Dict()
    data = Typed(np.ndarray)
    prefit_x = Typed(object)
    result_dict = Typed(object)
    # Element names shown by the enaml looper (see update_name_list).
    result_dict_names = List()
    # Working copy of the parameters; all edits go here.
    param_new = Dict()
    total_y = Typed(object)
    #total_l = Dict()
    #total_m = Dict()
    #total_pileup = Dict()
    # Element line currently being added manually (e.g. 'Na_K') -- presumed
    # from usage in manual_input; confirm against the GUI bindings.
    e_name = Str()
    add_element_intensity = Float(1000.0)
    element_list = List()
    #data_sets = Typed(OrderedDict)
    # ElementController instance managing the per-element fit results.
    EC = Typed(object)
    x0 = Typed(np.ndarray)
    y0 = Typed(np.ndarray)
    # Decimal digits used when rounding peak maxima/areas.
    max_area_dig = Int(2)
    pileup_data = Dict()
    auto_fit_all = Dict()
    # Peaks below this bound may be dropped by auto peak finding.
    bound_val = Float(1.0)

    def __init__(self, **kwargs):
        """Initialize the model.

        Parameters
        ----------
        **kwargs
            Optional keywords; ``default_parameters`` (dict) seeds the
            default/working parameter dicts and the element list.
        """
        try:
            # default parameter is the original parameter, for user to restore
            self.default_parameters = kwargs['default_parameters']
            self.param_new = copy.deepcopy(self.default_parameters)
            self.element_list = get_element(self.param_new)
        except (KeyError, ValueError):
            # Bug fix: a missing 'default_parameters' key raises KeyError,
            # which the original `except ValueError` never caught, so the
            # intended fallback below was unreachable for the common case.
            logger.info('No default parameter files are chosen.')
        self.EC = ElementController()
        self.pileup_data = {'element1': 'Si_K',
                            'element2': 'Si_K',
                            'intensity': 0.0}

    def default_param_update(self, change):
        """Observer hook connected to the fileio model in gui.py startup.

        Resets the pristine defaults, the working parameter copy, and the
        derived element list from the observer payload.

        Parameters
        ----------
        change : dict
            Atom observer change dictionary; the new parameter dict is
            carried in ``change['value']``.
        """
        self.default_parameters = change['value']
        # Keep the stored defaults pristine: all later edits happen on a
        # deep copy held in param_new.
        self.param_new = copy.deepcopy(self.default_parameters)
        self.element_list = get_element(self.param_new)

    def get_new_param_from_file(self, param_path):
        """
        Update parameters if new param_path is given.

        Parameters
        ----------
        param_path : str
            Path of the parameter JSON file to load.
        """
        # Load the file and make it the new baseline parameter dict.
        with open(param_path, 'r') as json_data:
            self.default_parameters = json.load(json_data)
        # Work on a deep copy so the loaded defaults stay pristine.
        self.param_new = copy.deepcopy(self.default_parameters)
        self.element_list = get_element(self.param_new)
        # Rebuild fit state from scratch: clear results, re-cut the energy
        # range, then recompute the per-element spectra.  Order matters.
        self.EC.delete_all()
        self.define_range()
        self.create_spectrum_from_file(self.param_new, self.element_list)
        logger.info('Elements read from file are: {}'.format(self.element_list))

    def update_new_param(self, param):
        """Adopt `param` as the new defaults and rebuild all fit state.

        Same pipeline as get_new_param_from_file, but takes the parameter
        dict directly instead of loading it from disk.

        Parameters
        ----------
        param : dict
            New parameter dictionary.
        """
        self.default_parameters = param
        self.param_new = copy.deepcopy(self.default_parameters)
        self.element_list = get_element(self.param_new)
        # Clear old results, re-cut the range, then recompute spectra.
        self.EC.delete_all()
        self.define_range()
        self.create_spectrum_from_file(self.param_new, self.element_list)

    def param_changed(self, change):
        """
        Observer function in the top-level gui.py startup

        Parameters
        ----------
        change : dict
            This is the dictionary that gets passed to a function
            with the @observe decorator
        """
        # Only the working copy is replaced; defaults are untouched.
        self.param_new = change['value']

    def exp_data_update(self, change):
        """
        Observer function to be connected to the fileio model
        in the top-level gui.py startup

        Parameters
        ----------
        change : dict
            This is the dictionary that gets passed to a function
            with the @observe decorator
        """
        # New experimental spectrum (1D array) from the fileio model.
        self.data = change['value']

    @observe('bound_val')
    def _update_bound(self, change):
        """Log the new auto-peak-finding cutoff when `bound_val` changes.

        Parameters
        ----------
        change : dict
            Atom observer change dictionary; the initial 'create' event
            is ignored so startup does not log.
        """
        if change['type'] != 'create':
            # Fixed grammar of the user-facing message ('cutted' -> 'cut').
            logger.info('Values smaller than bound {} can be cut on Auto peak finding.'.format(self.bound_val))

    def define_range(self):
        """Cut the x range according to the bounds configured in param_new.

        Reads the low/high energy bounds plus the energy-calibration offset
        and linear term from ``param_new``, then stores the trimmed spectrum
        axes on ``self.x0`` / ``self.y0`` via the module-level
        ``define_range`` helper.
        """
        params = self.param_new
        bounds = params['non_fitting_values']
        self.x0, self.y0 = define_range(
            self.data,
            bounds['energy_bound_low']['value'],
            bounds['energy_bound_high']['value'],
            params['e_offset']['value'],
            params['e_linear']['value'],
        )

    def create_spectrum_from_file(self, param_dict, elemental_lines):
        """
        Create spectrum profile with given param dict from file.

        Builds a PreFitStatus entry per profile component (background,
        escape, pileup, element/compton/elastic lines) and registers them
        all with the ElementController.

        Parameters
        ----------
        param_dict : dict
            dict obtained from file
        elemental_lines : list
            e.g., ['Na_K', Mg_K', 'Pt_M'] refers to the
            K lines of Sodium, the K lines of Magnesium, and the M
            lines of Platinum
        """
        self.prefit_x, pre_dict, area_dict = calculate_profile(self.x0,
                                                               self.y0,
                                                               param_dict,
                                                               elemental_lines)
        # add escape peak
        if param_dict['non_fitting_values']['escape_ratio'] > 0:
            pre_dict['escape'] = trim_escape_peak(self.data,
                                                  param_dict, len(self.y0))

        temp_dict = OrderedDict()
        for e in six.iterkeys(pre_dict):
            if e in ['background', 'escape']:
                spectrum = pre_dict[e]

                # summed spectrum here is not correct,
                # as the interval is assumed as 1, not energy interval
                # however area of background and escape is not used elsewhere, not important
                area = np.sum(spectrum)

                ps = PreFitStatus(z=get_Z(e), energy=get_energy(e),
                                  area=float(area), spectrum=spectrum,
                                  maxv=float(np.around(np.max(spectrum), self.max_area_dig)),
                                  norm=-1, lbd_stat=False)
                temp_dict[e] = ps

            elif '-' in e:  # pileup peaks
                # Pileup energy is the sum of the two constituent lines.
                e1, e2 = e.split('-')
                energy = float(get_energy(e1))+float(get_energy(e2))
                spectrum = pre_dict[e]
                area = area_dict[e]

                ps = PreFitStatus(z=get_Z(e), energy=str(energy),
                                  area=area, spectrum=spectrum,
                                  maxv=np.around(np.max(spectrum), self.max_area_dig),
                                  norm=-1, lbd_stat=False)
                temp_dict[e] = ps

            else:
                # Regular element line (plus compton/elastic): find the
                # matching amplitude key in param_dict, then record a
                # PreFitStatus.  Non-matching keys are skipped via continue.
                ename = e.split('_')[0]
                for k, v in six.iteritems(param_dict):
                    if ename in k and 'area' in k:
                        spectrum = pre_dict[e]
                        area = area_dict[e]

                    elif ename == 'compton' and k == 'compton_amplitude':
                        spectrum = pre_dict[e]
                        area = area_dict[e]

                    elif ename == 'elastic' and k == 'coherent_sct_amplitude':
                        spectrum = pre_dict[e]
                        area = area_dict[e]

                    else:
                        continue

                    # NOTE: executed once per matching key; the last match
                    # wins if several parameter keys match this element.
                    ps = PreFitStatus(z=get_Z(ename), energy=get_energy(e),
                                      area=area, spectrum=spectrum,
                                      maxv=np.around(np.max(spectrum), self.max_area_dig),
                                      norm=-1, lbd_stat=False)

                    temp_dict[e] = ps
        self.EC.add_to_dict(temp_dict)

    def manual_input(self):
        """Add the element line named in ``self.e_name`` by hand.

        Computes the line's profile with a nominal area, rescales it so its
        maximum equals ``self.add_element_intensity``, and registers the
        resulting PreFitStatus with the ElementController.
        """
        # Nominal area fed to the profile calculation before rescaling.
        default_area = 1e2

        # if self.e_name == 'escape':
        #     self.param_new['non_fitting_values']['escape_ratio'] = (self.add_element_intensity
        #                                                             / np.max(self.y0))
        #     es_peak = trim_escape_peak(self.data, self.param_new,
        #                                len(self.y0))
        #     ps = PreFitStatus(z=get_Z(self.e_name),
        #                       energy=get_energy(self.e_name),
        #                       # put float in front of area and maxv
        #                       # due to type conflicts in atom, which regards them as
        #                       # np.float32 if we do not put float in front.
        #                       area=float(np.around(np.sum(es_peak), self.max_area_dig)),
        #                       spectrum=es_peak,
        #                       maxv=float(np.around(np.max(es_peak), self.max_area_dig)),
        #                       norm=-1, lbd_stat=False)
        #     logger.info('{} peak is added'.format(self.e_name))
        #
        # else:
        x, data_out, area_dict = calculate_profile(self.x0,
                                                   self.y0,
                                                   self.param_new,
                                                   elemental_lines=[self.e_name],
                                                   default_area=default_area)

        # Scale factor so the peak maximum matches the requested intensity.
        ratio_v = self.add_element_intensity / np.max(data_out[self.e_name])

        ps = PreFitStatus(z=get_Z(self.e_name),
                          energy=get_energy(self.e_name),
                          area=area_dict[self.e_name]*ratio_v,
                          spectrum=data_out[self.e_name]*ratio_v,
                          maxv=self.add_element_intensity,
                          norm=-1,
                          status=True,    # for plotting
                          lbd_stat=False)

        self.EC.add_to_dict({self.e_name: ps})

    def add_pileup(self):
        """Add a pileup peak built from the two elements in ``pileup_data``.

        Computes the combined profile for 'element1-element2', rescales it
        so its maximum matches the requested intensity, and registers the
        resulting PreFitStatus.  Does nothing when the requested intensity
        is zero.
        """
        default_area = 1e2
        if self.pileup_data['intensity'] != 0:
            e_name = (self.pileup_data['element1'] + '-'
                      + self.pileup_data['element2'])
            # parse elemental lines into multiple lines

            x, data_out, area_dict = calculate_profile(self.x0,
                                                       self.y0,
                                                       self.param_new,
                                                       elemental_lines=[e_name],
                                                       default_area=default_area)
            # Pileup energy is the sum of the two constituent line energies.
            energy = str(float(get_energy(self.pileup_data['element1']))
                         + float(get_energy(self.pileup_data['element2'])))

            # Scale factor so the peak maximum matches the requested intensity.
            ratio_v = self.pileup_data['intensity'] / np.max(data_out[e_name])

            ps = PreFitStatus(z=get_Z(e_name),
                              energy=energy,
                              area=area_dict[e_name]*ratio_v,
                              spectrum=data_out[e_name]*ratio_v,
                              maxv=self.pileup_data['intensity'],
                              norm=-1,
                              status=True,    # for plotting
                              lbd_stat=False)
            logger.info('{} peak is added'.format(e_name))
            # Bug fix: this call was dedented outside the `if`, so a zero
            # intensity raised NameError on `e_name`/`ps` instead of being
            # a no-op.  Registering the peak belongs inside the branch.
            self.EC.add_to_dict({e_name: ps})

    def update_name_list(self):
        """Refresh ``result_dict_names`` from the element controller.

        The enaml looper bound to ``result_dict_names`` only rebuilds when
        the member changes, so the list is cleared before the new names are
        assigned to guarantee a change notification.
        """
        # Clear first so the GUI refreshes even if the names are unchanged.
        self.result_dict_names = []
        names = list(self.EC.element_dict.keys())
        self.result_dict_names = names

    def find_peak(self, threshv=0.1):
        """
        Run automatic peak finding, and save results as dict of object.

        Parameters
        ----------
        threshv : float
            The value will not be shown on GUI if it is smaller than the
            threshold.  NOTE(review): this parameter is not referenced in
            the body -- presumably applied downstream; confirm.
        """
        self.define_range()  # in case the energy calibration changes
        # Linear fit of the trimmed spectrum yields one profile and area
        # per detected component.
        self.prefit_x, out_dict, area_dict = linear_spectrum_fitting(self.x0,
                                                                     self.y0,
                                                                     self.param_new)
        logger.info('Energy range: {}, {}'.format(
            self.param_new['non_fitting_values']['energy_bound_low']['value'],
            self.param_new['non_fitting_values']['energy_bound_high']['value']))

        # Wrap each fitted component in a PreFitStatus for the controller.
        prefit_dict = OrderedDict()
        for k, v in six.iteritems(out_dict):
            ps = PreFitStatus(z=get_Z(k),
                              energy=get_energy(k),
                              area=area_dict[k],
                              spectrum=v,
                              maxv=np.around(np.max(v), self.max_area_dig),
                              norm=-1,
                              lbd_stat=False)
            prefit_dict.update({k: ps})

        logger.info('Automatic Peak Finding found elements as : {}'.format(
            list(prefit_dict.keys())))
        # Replace any previous results wholesale.
        self.EC.delete_all()
        self.EC.add_to_dict(prefit_dict)

    def create_full_param(self):
        """
        Extend the param to full param dict including each element's
        information, and assign initial values from pre fit.

        Area values stored in the ElementController (from auto/manual peak
        finding) are copied into the matching ``*_area`` entries of
        ``param_new``; compton/elastic amplitudes and the escape ratio are
        handled as special cases.
        """
        self.define_range()
        self.element_list = self.EC.get_element_list()
        # self.param_new['non_fitting_values']['element_list'] = ', '.join(self.element_list)
        #
        # # first remove some nonexisting elements
        # # remove elements not included in self.element_list
        # self.param_new = param_dict_cleaner(self.param_new,
        #                                     self.element_list)
        #
        # # second add some elements to a full parameter dict
        # # create full parameter list including elements
        # PC = ParamController(self.param_new, self.element_list)
        # # parameter values not updated based on param_new, so redo it
        # param_temp = PC.params
        # for k, v in six.iteritems(param_temp):
        #     if k == 'non_fitting_values':
        #         continue
        #     if self.param_new.has_key(k):
        #         v['value'] = self.param_new[k]['value']
        # self.param_new = param_temp
        #
        # # to create full param dict, for GUI only
        # create_full_dict(self.param_new, fit_strategy_list)

        self.param_new = update_param_from_element(self.param_new, self.element_list)
        # Partition the element list: short names are plain lines, '-' marks
        # pileups, 'user' marks user-defined peaks.
        element_temp = [e for e in self.element_list if len(e) <= 4]
        pileup_temp = [e for e in self.element_list if '-' in e]
        userpeak_temp = [e for e in self.element_list if 'user' in e.lower()]

        # update area values in param_new according to results saved in ElementController
        if len(self.EC.element_dict):
            for k, v in six.iteritems(self.param_new):
                if 'area' in k:
                    if 'pileup' in k:
                        name_cut = k[7:-5]  #remove pileup_ and _area
                        for p in pileup_temp:
                            if name_cut == p.replace('-', '_'):
                                v['value'] = self.EC.element_dict[p].area
                    elif 'user' in k.lower():
                        for p in userpeak_temp:
                            if p in k:
                                v['value'] = self.EC.element_dict[p].area
                    else:
                        # Match e.g. key 'Na_ka1_area' against element 'Na_K'
                        # by name plus first letter of the line.
                        for e in element_temp:
                            k_name, k_line, _ = k.split('_')
                            e_name, e_line = e.split('_')
                            if k_name == e_name and e_line.lower() == k_line[0]:  # attention: S_k and As_k
                                v['value'] = self.EC.element_dict[e].area

            # Compton/elastic amplitudes live under dedicated keys.
            if 'compton' in self.EC.element_dict:
                self.param_new['compton_amplitude']['value'] = self.EC.element_dict['compton'].area
            if 'coherent_sct_amplitude' in self.EC.element_dict:
                self.param_new['coherent_sct_amplitude']['value'] = self.EC.element_dict['elastic'].area

            # Escape ratio is normalized by the spectrum maximum; reset to 0
            # when no escape peak was found.
            if 'escape' in self.EC.element_dict:
                self.param_new['non_fitting_values']['escape_ratio'] = (self.EC.element_dict['escape'].maxv
                                                                        / np.max(self.y0))
            else:
                self.param_new['non_fitting_values']['escape_ratio'] = 0.0

    def data_for_plot(self):
        """
        Save data in terms of K, L, M lines for plot.
        """
        self.total_y = None
        self.auto_fit_all = {}

        for k, v in six.iteritems(self.EC.element_dict):
            if v.status is True:
                self.auto_fit_all[k] = v.spectrum
                if self.total_y is None:
                    self.total_y = np.array(v.spectrum)  # need to copy an array
                else:
                    self.total_y += v.spectrum