Пример #1
0
class ScanModel(Model):
    """Model for a scan over one or more axes.

    Subclasses provide the point data (see :meth:`get_point_data`). This base class
    manages the online analyses configured for the scan and the annotations built
    from the annotation schemata, some of whose data sources read results from
    those online analyses.
    """

    #: Emitted (with a point data dict) when previously reported points have been
    #: replaced.
    points_rewritten = QtCore.pyqtSignal(dict)
    #: Emitted (with a point data dict) when new points have been added.
    points_appended = QtCore.pyqtSignal(dict)
    #: Emitted with the new list of :class:`Annotation` objects.
    annotations_changed = QtCore.pyqtSignal(list)

    def __init__(self, axes: List[Dict[str, Any]], context: Context):
        super().__init__(context)
        # Schemata of the scan axes.
        self.axes = axes
        self._annotations = []
        self._annotation_schemata = []
        # Online analyses by name; referenced by "analysis_result" data sources.
        self._online_analyses = {}

    def get_point_data(self) -> Dict[str, Any]:
        """Return the point data of the scan. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_annotations(self) -> List[Annotation]:
        """Return the current list of annotations."""
        return self._annotations

    def _set_annotation_schemata(self, schemata: List[Dict[str, Any]]) -> None:
        """Rebuild all annotations from ``schemata`` and emit
        :attr:`annotations_changed`.

        Data sources of kind ``"fixed"`` wrap a constant value; ``"analysis_result"``
        sources are bound to the currently configured online analyses. Unknown source
        kinds, and references to analyses that do not exist, become ``None``
        entries.
        """
        self._annotation_schemata = schemata
        self._annotations = []

        def data_source(spec):
            kind = spec["kind"]
            if kind == "fixed":
                return FixedDataSource(spec["value"])
            if kind == "analysis_result":
                analysis = self._online_analyses.get(spec["analysis_name"], None)
                if analysis is None:
                    return None
                return OnlineAnalysisDataSource(analysis, spec["result_key"])
            logger.info("Ignoring unsupported annotation data source type: '%s'", kind)
            return None

        def to_data_sources(specs):
            # A schema may omit "coordinates" or "data" altogether; treat that as
            # empty instead of failing on ``None.items()``.
            return {k: data_source(v) for k, v in (specs or {}).items()}

        for schema in schemata:
            sources = [to_data_sources(schema.get(n)) for n in ("coordinates", "data")]
            self._annotations.append(
                Annotation(schema["kind"], schema.get("parameters", {}), *sources))
        self.annotations_changed.emit(self._annotations)

    def _set_online_analyses(self,
                             analysis_schemata: Dict[str, Dict[str, Any]]) -> None:
        """Replace all online analyses with ones built from ``analysis_schemata``.

        Existing analyses are stopped first. Only the ``"named_fit"`` kind is
        supported; others are skipped with a warning. Afterwards, the annotations
        are rebuilt so that their data sources refer to the new analyses.
        """
        for a in self._online_analyses.values():
            a.stop()
        self._online_analyses = {}

        for name, schema in analysis_schemata.items():
            kind = schema["kind"]
            if kind == "named_fit":
                self._online_analyses[name] = OnlineNamedFitAnalysis(schema, self)
            else:
                logger.warning("Ignoring unsupported online analysis type: '%s'", kind)

        # Rebind annotation schemata to new analysis data sources.
        self._set_annotation_schemata(self._annotation_schemata)
Пример #2
0
class Context(QtCore.QObject):
    """Describes the environment in which a certain plot is displayed.

    This is the moral equivalent of a container for global variables and should be used
    only sparsely (i.e. for actual properties of the environment).
    """

    source_id_changed = QtCore.pyqtSignal(str)
    title_changed = QtCore.pyqtSignal(str)

    def __init__(self, set_dataset: Callable[[str, Any], None] = None):
        """
        :param set_dataset: Callback ``(key, value)`` used to set datasets on the
            ARTIQ master, or ``None`` if there is no master connection.
        """
        super().__init__()
        self._set_dataset = set_dataset
        self._title = ""
        self._source_id = "<unknown>"

    def get_title(self) -> str:
        """Return the current plot title."""
        return self._title

    def set_title(self, title: str) -> None:
        """Set the plot title, emitting :attr:`title_changed` if it changed."""
        if self._title != title:
            self._title = title
            self.title_changed.emit(title)

    def get_source_id(self):
        """Return a short string that helps the user to identify the data source.

        This is usually based on the run id, and shown in plots for data traceability
        purposes.
        """
        return self._source_id

    def set_source_id(self, source_id):
        """Set the source id, emitting :attr:`source_id_changed` if it changed."""
        if self._source_id != source_id:
            self._source_id = source_id
            self.source_id_changed.emit(source_id)

    def is_online_master(self) -> bool:
        """Return whether the plot is run in an environment where there is a connection
        to an ARTIQ master (as opposed to e.g. displaying an offline results file).
        """
        return self._set_dataset is not None

    def set_dataset(self, key: str, value: Any) -> None:
        """Sets dataset ``key`` to ``value`` on the connected master, if any.

        If there is no connected master, this does nothing (instead of failing with
        a ``TypeError`` from calling ``None``).

        See: :meth:`is_online_master`.
        """
        if self._set_dataset is None:
            return
        self._set_dataset(key, value)
Пример #3
0
class Model(QtCore.QObject):
    """Common base for plot data models.

    Every model carries the :class:`Context` describing the environment the plot is
    shown in.
    """

    #: Emitted with the new channel schemata dict when they change.
    channel_schemata_changed = QtCore.pyqtSignal(dict)

    def __init__(self, context: Context):
        super().__init__()
        self.context = context

    def get_channel_schemata(self) -> Dict[str, Any]:
        """Return the result channel schemata; subclasses must override this."""
        raise NotImplementedError()
Пример #4
0
class Root(QtCore.QObject):
    """Top of a plot data tree, i.e. everything making up the plot shown in a given
    window.

    A root refers to (at most) one :class:`Model`. This extra level of indirection
    lets us represent a model that is not known yet (e.g. while still waiting for
    the experiment to set the top-level metadata datasets), or one that may be
    swapped out (e.g. when showing the subscan for a point the user selected).
    """

    #: Emitted when the model changes; carries the new model object.
    model_changed = QtCore.pyqtSignal(object)

    def get_model(self) -> Union["Model", None]:
        """Return the current model, or ``None``. Subclasses must override this."""
        raise NotImplementedError()
Пример #5
0
class Context(QtCore.QObject):
    """Environment in which a plot is displayed (title, connection to the master)."""

    title_changed = QtCore.pyqtSignal(str)

    def __init__(self, set_dataset: Callable[[str, Any], None] = None):
        """
        :param set_dataset: Callback ``(key, value)`` used to set datasets on the
            master, or ``None`` if there is no master connection.
        """
        super().__init__()
        self._set_dataset = set_dataset
        self.title = ""

    def set_title(self, title: str) -> None:
        """Set the plot title, emitting :attr:`title_changed` if it changed."""
        if title != self.title:
            self.title = title
            self.title_changed.emit(title)

    def is_online_master(self) -> bool:
        """Return whether a callback for setting datasets on the master is present."""
        # Check the stored callback: the bound method ``self.set_dataset`` is never
        # None, so testing it made this always return True.
        return self._set_dataset is not None

    def set_dataset(self, key: str, value: Any) -> None:
        """Set dataset ``key`` to ``value`` on the connected master, if any.

        Does nothing if there is no master connection (see :meth:`is_online_master`).
        """
        if self._set_dataset is None:
            return
        self._set_dataset(key, value)
Пример #6
0
class AlternateMenuPlotWidget(ContextMenuPlotWidget):
    """Plot widget whose context menu offers switching to alternate plots.

    Meant for integration with the alternate plot switching functionality of
    .container.PlotContainerWidget: selecting an entry emits
    :attr:`alternate_plot_requested` with the chosen plot name.
    """

    alternate_plot_requested = QtCore.pyqtSignal(str)

    def __init__(self, get_alternate_plot_names):
        """
        :param get_alternate_plot_names: Callable returning the names of all plots
            that can be shown; queried each time the menu is built.
        """
        super().__init__()
        self._get_alternate_plot_names = get_alternate_plot_names

    def build_context_menu(self, builder: ContextMenuBuilder) -> None:
        plot_names = self._get_alternate_plot_names()
        # A single plot has nothing to switch to, so only offer entries for two or
        # more.
        if len(plot_names) > 1:
            for plot_name in plot_names:
                entry = builder.append_action("Show " + plot_name)
                # Bind the current name as a default argument; extra signal args
                # (e.g. the checked flag) are swallowed by *_.
                entry.triggered.connect(
                    lambda *_, plot_name=plot_name: self.alternate_plot_requested.emit(
                        plot_name))
        builder.ensure_separator()
Пример #7
0
class _XYSeries(QtCore.QObject):
    """A single data series in an XY plot: scatter points, optional error bars and
    an optional online fit.

    Error bars are drawn with a total height of twice the channel's error value;
    :meth:`_recompute_fit` reads the errors back as ``height / 2``, so both display
    paths in :meth:`update` must use the same factor.
    """

    def __init__(self,
                 plot,
                 data_name,
                 data_item,
                 error_bar_name,
                 error_bar_item,
                 plot_left_to_right,
                 fit_spec=None,
                 fit_item=None,
                 fit_pois=None):
        """
        :param plot: Plot the items are added to; also used as the Qt parent.
        :param data_name: Name of the result channel shown as y values.
        :param data_item: Item displaying the data points.
        :param error_bar_name: Name of the channel holding the y errors, if any.
        :param error_bar_item: Error bar item, or ``None`` for no error bars.
        :param plot_left_to_right: If true, points are sorted by x before display.
        :param fit_spec: Optional fit specification (must contain ``"fit_type"``).
        :param fit_item: Item the fitted curve is drawn into.
        :param fit_pois: Point-of-interest objects updated with the fit results.
            ``None`` (the default) means none; this avoids a shared mutable default.
        """
        super().__init__(plot)

        self.plot = plot
        self.data_item = data_item
        self.data_name = data_name
        self.error_bar_item = error_bar_item
        self.error_bar_name = error_bar_name
        self.plot_left_to_right = plot_left_to_right
        self.num_current_points = 0
        self.fit_obj = None

        if fit_spec:
            self._install_fit(fit_spec, fit_item,
                              [] if fit_pois is None else fit_pois)

    def update(self, x_data, data):
        """Display the points available so far.

        :param x_data: Sequence of x coordinates.
        :param data: Dict with ``"ndscan.points.channel_<name>"`` entries, each a
            pair whose second element is the list of values.
        """
        def channel(name):
            return data.get("ndscan.points.channel_" + name, (False, []))[1]

        y_data = channel(self.data_name)
        num_to_show = min(len(x_data), len(y_data))

        if self.error_bar_item:
            y_err = channel(self.error_bar_name)
            num_to_show = min(num_to_show, len(y_err))

        # Only the number of points is compared; nothing new to show.
        if num_to_show == self.num_current_points:
            return

        if self.plot_left_to_right:
            x_data = np.array(x_data)
            order = np.argsort(x_data[:num_to_show])

            y_data = np.array(y_data)
            self.data_item.setData(x_data[order], y_data[order])
            if self.num_current_points == 0:
                self.plot.addItem(self.data_item)

            if self.error_bar_item:
                y_err = np.array(y_err)
                # Full bar height is 2x the error, as in the branch below (and as
                # expected by _recompute_fit()).
                self.error_bar_item.setData(x=x_data[order],
                                            y=y_data[order],
                                            height=2 * y_err[order])
                if self.num_current_points == 0:
                    self.plot.addItem(self.error_bar_item)
        else:
            self.data_item.setData(x_data[:num_to_show], y_data[:num_to_show])
            if self.num_current_points == 0:
                self.plot.addItem(self.data_item)

            if self.error_bar_item:
                self.error_bar_item.setData(
                    x=x_data[:num_to_show],
                    y=y_data[:num_to_show],
                    height=(2 * np.array(y_err[:num_to_show])))
                if self.num_current_points == 0:
                    self.plot.addItem(self.error_bar_item)

        self.num_current_points = num_to_show

        # Only fit once there are at least as many points as free parameters.
        if self.fit_obj and self.num_current_points >= len(
                self.fit_obj.parameter_names):
            self._trigger_recompute_fit.emit()

    _trigger_recompute_fit = QtCore.pyqtSignal()

    def _install_fit(self, spec, item, pois):
        """Set up the fit state and the rate-limited recompute/redraw machinery."""
        self.fit_type = spec["fit_type"]
        self.fit_obj = FIT_OBJECTS[self.fit_type]
        self.fit_item = item
        self.fit_pois = pois
        self.fit_item_added = False

        self.last_fit_params = None
        self.last_fit_errors = None

        self.recompute_fit_limiter = pyqtgraph.SignalProxy(
            self._trigger_recompute_fit,
            slot=lambda: asyncio.ensure_future(self._recompute_fit()),
            rateLimit=30)
        self.recompute_in_progress = False
        # Fits run in a separate process to keep the GUI responsive.
        self.fit_executor = ProcessPoolExecutor(max_workers=1)

        self.redraw_fit_limiter = pyqtgraph.SignalProxy(
            self.plot.getPlotItem().getViewBox().sigXRangeChanged,
            slot=self._redraw_fit,
            rateLimit=30)

    async def _recompute_fit(self):
        """Run the fit on the currently displayed data and schedule a redraw."""
        if self.recompute_in_progress:
            # Run at most one fit computation at a time. To make sure we don't
            # leave a few final data points completely disregarded, just
            # re-emit the signal – even for long fits, repeated checks aren't
            # expensive, as long as the SignalProxy rate is slow enough.
            self._trigger_recompute_fit.emit()
            return

        self.recompute_in_progress = True
        try:
            xs, ys = self.data_item.getData()
            y_errs = None
            if self.error_bar_item:
                y_errs = self.error_bar_item.opts['height'] / 2

            loop = asyncio.get_event_loop()
            self.last_fit_params, self.last_fit_errors = await loop.run_in_executor(
                self.fit_executor, _run_fit, self.fit_type, xs, ys, y_errs)
            self.redraw_fit_limiter.signalReceived()
        finally:
            # Always clear the flag: if a fit fails, a stuck flag would block all
            # further fits and make the re-emit above loop forever.
            self.recompute_in_progress = False

    def _redraw_fit(self, *args):
        """Redraw the fitted curve across (slightly more than) the visible x range."""
        if not self.last_fit_params:
            return

        if not self.fit_item_added:
            self.plot.addItem(self.fit_item, ignoreBounds=True)
            for f in self.fit_pois:
                f.add_to_plot(self.plot)
            self.fit_item_added = True

        # Choose horizontal range based on currently visible area.
        view_box = self.plot.getPlotItem().getViewBox()
        x_range, _ = view_box.state["viewRange"]
        ext = (x_range[1] - x_range[0]) / 10
        x_lims = (x_range[0] - ext, x_range[1] + ext)

        # Choose number of points based on width of plot on screen (in pixels).
        # The width is a float, but np.linspace requires an integer count.
        fit_xs = np.linspace(*x_lims, int(view_box.width()))

        fit_ys = self.fit_obj.fitting_function(fit_xs, self.last_fit_params)

        self.fit_item.setData(fit_xs, fit_ys)

        for f in self.fit_pois:
            f.update(self.last_fit_params, self.last_fit_errors)
Пример #8
0
class AnnotationDataSource(QtCore.QObject):
    """Abstract provider of a single value used by an annotation.

    :attr:`changed` is a parameterless signal; presumably emitted by subclasses
    when the value returned by :meth:`get` changes.
    """

    changed = QtCore.pyqtSignal()

    def get(self) -> Any:
        """Return the current value; subclasses must override this."""
        raise NotImplementedError()
Пример #9
0
class Rolling1DPlotWidget(pyqtgraph.PlotWidget):
    """Plot of a rolling history of scalar result channels, one point per new
    ``point_phase`` value. The history length is set from the context menu.
    """

    error = QtCore.pyqtSignal(str)

    def __init__(self):
        super().__init__()

        self.series_initialised = False
        self.series = []

        # Last seen value of the "ndscan.point_phase" dataset.
        self.point_phase = None

        self.showGrid(x=True, y=True)

        self._install_context_menu()

    def data_changed(self, data, mods):
        """Handle a dataset update.

        Sets up the series once the ``ndscan.channels`` dataset is available, then
        appends a point to every series whenever ``ndscan.point_phase`` changes.
        """
        def d(name):
            return data.get("ndscan." + name, (False, None))[1]

        if not self.series_initialised:
            channels_json = d("channels")
            if not channels_json:
                return

            channels = json.loads(channels_json)

            try:
                data_names, error_bar_names = extract_scalar_channels(channels)
            except ValueError as e:
                self.error.emit(str(e))
                # Nothing can be set up without valid channels (continuing would
                # hit an unbound ``data_names``); series_initialised stays False.
                return

            for i, data_name in enumerate(data_names):
                color = SERIES_COLORS[i % len(SERIES_COLORS)]
                data_item = pyqtgraph.ScatterPlotItem(pen=None, brush=color)

                error_bar_name = error_bar_names.get(data_name, None)
                error_bar_item = pyqtgraph.ErrorBarItem(
                    pen=color) if error_bar_name else None

                self.series.append(
                    _Series(self, data_name, data_item, error_bar_name,
                            error_bar_item, self.num_history_box.value()))

            if len(data_names) == 1:
                # If there is only one series, set label/scaling accordingly.
                # TODO: Add multiple y axis for additional channels.
                c = channels[data_names[0]]

                label = c["description"]
                if not label:
                    label = c["path"].split("/")[-1]
                setup_axis_item(self.getAxis("left"), label, c["path"], c)

            self.series_initialised = True

        # FIXME: Phase check will miss points when using mod buffering - need to
        # directly read all the data from mods.
        phase = d("point_phase")
        if phase is not None and phase != self.point_phase:
            for s in self.series:
                s.append(data)
            self.point_phase = phase

    def set_history_length(self, n):
        """Set the number of points kept by every series."""
        for s in self.series:
            s.set_history_length(n)

    def _install_context_menu(self):
        """Add a spin box for the history length to the plot's context menu."""
        self.num_history_box = QtWidgets.QSpinBox()
        self.num_history_box.setMinimum(1)
        self.num_history_box.setMaximum(2**16)
        self.num_history_box.setValue(100)
        self.num_history_box.valueChanged.connect(self.set_history_length)

        container = QtWidgets.QWidget()

        layout = QtWidgets.QHBoxLayout()
        container.setLayout(layout)

        label = QtWidgets.QLabel("N: ")
        layout.addWidget(label)

        layout.addWidget(self.num_history_box)

        action = QtWidgets.QWidgetAction(self)
        action.setDefaultWidget(container)

        separator = QtWidgets.QAction("", self)
        separator.setSeparator(True)
        entries = [action, separator]
        self.plotItem.getContextMenus = lambda ev: entries
Пример #10
0
class SinglePointModel(Model):
    """Model exposing one point of results at a time."""

    #: Emitted when the current point changes; carries the point (object).
    point_changed = QtCore.pyqtSignal(object)

    def get_point(self) -> Union[None, Dict[str, Any]]:
        """Return the current point, or ``None``; subclasses must override this."""
        raise NotImplementedError()
Пример #11
0
class Image2DPlotWidget(pyqtgraph.PlotWidget):
    """Image plot of a scalar result channel over two scan axes.

    The context menu allows choosing the displayed channel and, for parameters with
    linked datasets, setting those datasets from the crosshair position.
    """

    error = QtCore.pyqtSignal(str)

    def __init__(self, x_schema, y_schema, set_dataset):
        """
        :param x_schema: Schema of the axis shown horizontally.
        :param y_schema: Schema of the axis shown vertically.
        :param set_dataset: Callback ``(key, value)`` used to set datasets.
        """
        super().__init__()
        self.x_schema = x_schema
        self.y_schema = y_schema

        self.set_dataset = set_dataset
        self.plot = None

        def setup_axis(schema, location):
            path = schema["path"]
            if not path:
                path = "/"
            identity_string = schema["param"]["fqn"] + "@" + path
            return setup_axis_item(self.getAxis(location),
                                   schema["param"]["description"],
                                   identity_string, schema["param"]["spec"])

        self.x_unit_suffix, self.x_data_to_display_scale = \
            setup_axis(x_schema, "bottom")
        self.y_unit_suffix, self.y_data_to_display_scale = \
            setup_axis(y_schema, "left")

        self.crosshair = LabeledCrosshairCursor(self, self, self.x_unit_suffix,
                                                self.x_data_to_display_scale,
                                                self.y_unit_suffix,
                                                self.y_data_to_display_scale)
        self.showGrid(x=True, y=True)

    def data_changed(self, datasets, mods):
        """Handle a dataset update, creating the image plot on first use."""
        def d(name):
            return datasets.get("ndscan." + name, (False, None))[1]

        if not self.plot:
            channels_json = d("channels")
            if not channels_json:
                return

            channels = json.loads(channels_json)

            try:
                data_names, _ = extract_scalar_channels(channels)
            except ValueError as e:
                self.error.emit(str(e))
                # Cannot set up the plot; self.plot stays None.
                return

            if not data_names:
                self.error.emit("No scalar result channels to display")
                # data_names[0] is needed below.
                return

            hints_for_channels = {
                name: channels[name].get("display_hints", {})
                for name in data_names
            }
            self._install_context_menu(data_names)

            def bounds(schema):
                return (schema.get(n, None)
                        for n in ("min", "max", "increment"))

            image_item = pyqtgraph.ImageItem()
            self.addItem(image_item)
            self.plot = _ImagePlot(image_item, data_names[0],
                                   *bounds(self.x_schema),
                                   *bounds(self.y_schema), hints_for_channels)

        self.plot.data_changed(datasets)

    def _install_context_menu(self, data_names):
        """Build the context menu: dataset-from-crosshair actions, then channels."""
        entries = []

        x_datasets = extract_linked_datasets(self.x_schema["param"])
        y_datasets = extract_linked_datasets(self.y_schema["param"])
        for d, axis in chain(zip(x_datasets, repeat("x")),
                             zip(y_datasets, repeat("y"))):
            action = QtWidgets.QAction("Set '{}' from crosshair".format(d),
                                       self)
            # Bind both loop variables as defaults; a plain closure over ``axis``
            # would make every action use the last axis.
            action.triggered.connect(lambda *a, d=d, axis=axis: self.
                                     _set_dataset_from_crosshair(d, axis))
            entries.append(action)
        if len(x_datasets) == 1 and len(y_datasets) == 1:
            action = QtWidgets.QAction("Set both from crosshair", self)

            def set_both():
                self._set_dataset_from_crosshair(x_datasets[0], "x")
                self._set_dataset_from_crosshair(y_datasets[0], "y")

            action.triggered.connect(set_both)
            entries.append(action)

        def append_separator():
            separator = QtWidgets.QAction("", self)
            separator.setSeparator(True)
            entries.append(separator)

        if entries:
            append_separator()

        self.channel_menu_group = QtWidgets.QActionGroup(self)
        first_action = None
        for name in data_names:
            action = QtWidgets.QAction(name, self)
            if not first_action:
                first_action = action
            action.setCheckable(True)
            action.setActionGroup(self.channel_menu_group)
            action.triggered.connect(
                lambda *a, name=name: self.plot.activate_channel(name))
            entries.append(action)
        if first_action:
            first_action.setChecked(True)
        append_separator()

        self.plotItem.getContextMenus = lambda ev: entries

    def _set_dataset_from_crosshair(self, dataset, axis):
        """Set ``dataset`` to the crosshair's last x or y coordinate."""
        if not self.crosshair:
            logger.warning(
                "Plot not initialised yet, ignoring set dataset request")
            return
        self.set_dataset(
            dataset,
            self.crosshair.last_x if axis == "x" else self.crosshair.last_y)
Пример #12
0
class XY1DPlotWidget(SubplotMenuPlotWidget):
    """Scatter plot of the scalar result channels of a scan against its first axis.

    Channels are grouped onto one or more y axes, and annotations from the model
    (location markers, curves, computed curves) are drawn on top. Clicking a point
    selects it in :attr:`selected_point_model`, from which subscan roots are
    created.
    """

    error = QtCore.pyqtSignal(str)
    ready = QtCore.pyqtSignal()

    def __init__(self, model: ScanModel, get_alternate_plot_names):
        super().__init__(model.context, get_alternate_plot_names)

        self.model = model
        self.model.channel_schemata_changed.connect(self._initialise_series)
        self.model.points_appended.connect(self._update_points)
        self.model.annotations_changed.connect(self._update_annotations)

        # FIXME: Just re-set values instead of throwing away everything.
        def rewritten(points):
            self._initialise_series(self.model.get_channel_schemata())
            self._update_points(points)

        self.model.points_rewritten.connect(rewritten)

        self.selected_point_model = SelectPointFromScanModel(self.model)

        self.annotation_items = []
        self.series = []

        x_schema = self.model.axes[0]
        self.x_param_spec = x_schema["param"]["spec"]
        self.x_unit_suffix, self.x_data_to_display_scale = setup_axis_item(
            self.getAxis("bottom"),
            [(x_schema["param"]["description"],
              format_param_identity(x_schema), None, self.x_param_spec)])
        self.y_unit_suffix = None
        self.y_data_to_display_scale = None
        self.crosshair = None
        self._highlighted_spot = None
        self.showGrid(x=True, y=True)

        view_box = self.getPlotItem().getViewBox()
        self.source_label = add_source_id_label(view_box, self.model.context)

        view_box.scene().sigMouseClicked.connect(self._handle_scene_click)

    def _initialise_series(self, channels):
        """(Re)create all series and y axes for the given channel schemata.

        Emits :attr:`error` and returns early if the channels cannot be displayed;
        otherwise restores annotations and emits :attr:`ready`.
        """
        # Remove all currently shown items and any extra axes added.
        for s in self.series:
            s.remove_items()
        self.series.clear()
        self._clear_annotations()
        self.reset_y_axes()

        try:
            data_names, error_bar_names = extract_scalar_channels(channels)
        except ValueError as e:
            self.error.emit(str(e))
            return

        series_idx = 0
        axes = group_channels_into_axes(channels, data_names)
        for names in axes:
            axis, view_box = self.new_y_axis()

            info = []
            for name in names:
                color = SERIES_COLORS[series_idx % len(SERIES_COLORS)]
                data_item = pyqtgraph.ScatterPlotItem(pen=None,
                                                      brush=color,
                                                      size=6)
                data_item.sigClicked.connect(self._point_clicked)

                error_bar_item = None
                error_bar_name = error_bar_names.get(name, None)
                if error_bar_name:
                    error_bar_item = pyqtgraph.ErrorBarItem(pen=color)

                self.series.append(
                    _XYSeries(view_box, name, data_item, error_bar_name,
                              error_bar_item, False))

                channel = channels[name]
                label = channel["description"]
                if not label:
                    label = channel["path"].split("/")[-1]
                info.append((label, channel["path"], color, channel))

                series_idx += 1

            suffix, scale = setup_axis_item(axis, info)
            if self.y_unit_suffix is None:
                # FIXME: Add multiple lines to the crosshair.
                self.y_unit_suffix = suffix
                self.y_data_to_display_scale = scale

        if self.crosshair is None:
            # FIXME: Reinitialise crosshair as necessary on schema changes.
            self.crosshair = LabeledCrosshairCursor(
                self, self.getPlotItem(), self.x_unit_suffix,
                self.x_data_to_display_scale, self.y_unit_suffix,
                self.y_data_to_display_scale)
        self.subscan_roots = create_subscan_roots(self.selected_point_model)

        # Make sure we put back annotations (if they haven't changed but the points
        # have been rewritten, there might not be an annotations_changed event).
        self._update_annotations()

        self.ready.emit()

    def _update_points(self, points):
        """Forward new point data (keyed by ``"axis_0"``, channels) to all series."""
        x_data = points["axis_0"]
        # Compare length to zero instead of using `not x_data` for NumPy array
        # compatibility.
        if len(x_data) == 0:
            return

        for s in self.series:
            s.update(x_data, points)

    def _clear_annotations(self):
        """Remove all annotation items from the plot."""
        for item in self.annotation_items:
            item.remove()
        self.annotation_items.clear()

    def _update_annotations(self):
        """Redraw all annotations from the model."""
        self._clear_annotations()

        if not self.series:
            # Every supported annotation is drawn into the view box of a series
            # (indexing self.series[0] below would fail). _initialise_series()
            # calls this again once the series exist.
            return

        def channel_ref_to_series_idx(ref):
            for i, s in enumerate(self.series):
                if "channel_" + s.data_name == ref:
                    return i
            return 0

        def make_curve_item(series_idx):
            color = FIT_COLORS[series_idx % len(FIT_COLORS)]
            pen = pyqtgraph.mkPen(color, width=3)
            return pyqtgraph.PlotCurveItem(pen=pen)

        annotations = self.model.get_annotations()
        for a in annotations:
            if a.kind == "location":
                if set(a.coordinates.keys()) == set(["axis_0"]):
                    associated_series_idx = max(
                        channel_ref_to_series_idx(chan) for chan in
                        a.parameters.get("associated_channels", [None]))

                    color = FIT_COLORS[associated_series_idx % len(FIT_COLORS)]
                    vb = self.series[associated_series_idx].view_box
                    line = VLineItem(a.coordinates["axis_0"],
                                     a.data.get("axis_0_error", None), vb,
                                     color, self.x_data_to_display_scale,
                                     self.x_unit_suffix)
                    self.annotation_items.append(line)
                    continue

            if a.kind == "curve":
                associated_series_idx = None
                for series_idx, series in enumerate(self.series):
                    match_coords = set(
                        ["axis_0", "channel_" + series.data_name])
                    if set(a.coordinates.keys()) == match_coords:
                        associated_series_idx = series_idx
                        break
                if associated_series_idx is not None:
                    curve = make_curve_item(associated_series_idx)
                    series = self.series[associated_series_idx]
                    vb = series.view_box
                    item = CurveItem(
                        a.coordinates["axis_0"],
                        a.coordinates["channel_" + series.data_name], vb,
                        curve)
                    self.annotation_items.append(item)
                    continue

            if a.kind == "computed_curve":
                function_name = a.parameters.get("function_name", None)
                if ComputedCurveItem.is_function_supported(function_name):
                    associated_series_idx = max(
                        channel_ref_to_series_idx(chan) for chan in
                        a.parameters.get("associated_channels", [None]))

                    x_limits = [
                        self.x_param_spec.get(n, None) for n in ("min", "max")
                    ]
                    curve = make_curve_item(associated_series_idx)
                    vb = self.series[associated_series_idx].view_box
                    item = ComputedCurveItem(function_name, a.data, vb, curve,
                                             x_limits)
                    self.annotation_items.append(item)
                    continue

            logger.info("Ignoring annotation of kind '%s' with coordinates %s",
                        a.kind, list(a.coordinates.keys()))

    def build_context_menu(self, builder):
        """Add "Set '<dataset>' from crosshair" actions (when online) to the menu."""
        x_schema = self.model.axes[0]

        if self.model.context.is_online_master():
            for d in extract_linked_datasets(x_schema["param"]):
                action = builder.append_action(
                    "Set '{}' from crosshair".format(d))
                # Bind ``d`` as a default argument; a plain closure would make
                # every action set the last dataset in the loop.
                action.triggered.connect(
                    lambda *a, d=d: self._set_dataset_from_crosshair_x(d))

        builder.ensure_separator()
        super().build_context_menu(builder)

    def _set_dataset_from_crosshair_x(self, dataset_key):
        """Set ``dataset_key`` to the crosshair's last x coordinate."""
        if not self.crosshair:
            logger.warning(
                "Plot not initialised yet, ignoring set dataset request")
            return
        self.model.context.set_dataset(dataset_key, self.crosshair.last_x)

    def _highlight_spot(self, spot):
        """Highlight ``spot`` (or nothing if ``None``), un-highlighting the previous."""
        if self._highlighted_spot is not None:
            self._highlighted_spot.resetPen()
            self._highlighted_spot = None
        if spot is not None:
            spot.setPen("y", width=2)
            self._highlighted_spot = spot

    def _point_clicked(self, scatter_plot_item, spot_items):
        """Select the clicked point in :attr:`selected_point_model`."""
        if not spot_items:
            # No points clicked – events don't seem to be emitted in this case
            # anyway.
            self._background_clicked()
            return

        # Arbitrarily choose the first element in the list if multiple spots
        # overlap; the user can always zoom in if that is undesired.
        spot = spot_items[0]
        self._highlight_spot(spot)
        self.selected_point_model.set_source_index(spot.index())

    def _background_clicked(self):
        """Clear the highlight and the point selection."""
        self._highlight_spot(None)
        self.selected_point_model.set_source_index(None)

    def _handle_scene_click(self, event):
        if not event.isAccepted():
            # Event not handled yet, so background/… was clicked instead of a point.
            self._background_clicked()
Пример #13
0
class Image2DPlotWidget(AlternateMenuPlotWidget):
    """Image plot of a scalar result channel over the two axes of a 2D scan.

    The context menu allows choosing the displayed channel and, when connected to a
    master, setting linked datasets from the crosshair position.
    """

    error = QtCore.pyqtSignal(str)
    ready = QtCore.pyqtSignal()

    def __init__(self, model: ScanModel, get_alternate_plot_names):
        super().__init__(get_alternate_plot_names)

        self.model = model
        self.model.channel_schemata_changed.connect(self._initialise_series)
        self.model.points_appended.connect(
            lambda p: self._update_points(p, False))
        self.model.points_rewritten.connect(
            lambda p: self._update_points(p, True))

        self.data_names = []

        self.x_schema, self.y_schema = self.model.axes
        self.plot = None

        def setup_axis(schema, location):
            param = schema["param"]
            return setup_axis_item(
                self.getAxis(location),
                [(param["description"], format_param_identity(schema), None,
                  param["spec"])])

        self.x_unit_suffix, self.x_data_to_display_scale = \
            setup_axis(self.x_schema, "bottom")
        self.y_unit_suffix, self.y_data_to_display_scale = \
            setup_axis(self.y_schema, "left")

        self.crosshair = LabeledCrosshairCursor(self, self.getPlotItem(),
                                                self.x_unit_suffix,
                                                self.x_data_to_display_scale,
                                                self.y_unit_suffix,
                                                self.y_data_to_display_scale)
        self.showGrid(x=True, y=True)

        self.source_label = add_source_id_label(
            self.getPlotItem().getViewBox(), self.model.context)

    def _initialise_series(self, channels):
        """(Re)create the image plot for the given channel schemata.

        Emits :attr:`error` and leaves ``self.plot`` as ``None`` if there is nothing
        displayable; otherwise emits :attr:`ready`.
        """
        if self.plot is not None:
            self.removeItem(self.plot.image_item)
            self.plot = None

        try:
            data_names, _ = extract_scalar_channels(channels)
        except ValueError as e:
            # Drop stale names so the context menu doesn't offer them.
            self.data_names = []
            self.error.emit(str(e))
            return
        self.data_names = data_names

        if not self.data_names:
            self.error.emit("No scalar result channels to display")
            # self.data_names[0] is needed below.
            return

        hints_for_channels = {
            name: channels[name].get("display_hints", {})
            for name in self.data_names
        }

        def bounds(schema):
            return (schema.get(n, None) for n in ("min", "max", "increment"))

        image_item = pyqtgraph.ImageItem()
        self.addItem(image_item)
        self.plot = _ImagePlot(image_item, self.data_names[0],
                               *bounds(self.x_schema), *bounds(self.y_schema),
                               hints_for_channels)
        self.ready.emit()

    def _update_points(self, points, invalidate):
        """Forward point data to the image plot, if it has been set up."""
        if self.plot:
            self.plot.data_changed(points, invalidate_previous=invalidate)

    def build_context_menu(self, builder):
        """Add crosshair dataset actions (when online) and channel selection."""
        if self.model.context.is_online_master():
            x_datasets = extract_linked_datasets(self.x_schema["param"])
            y_datasets = extract_linked_datasets(self.y_schema["param"])
            for d, axis in chain(zip(x_datasets, repeat("x")),
                                 zip(y_datasets, repeat("y"))):
                action = builder.append_action(
                    "Set '{}' from crosshair".format(d))
                action.triggered.connect(lambda *a, axis=axis, d=d: (
                    self._set_dataset_from_crosshair(d, axis)))
            if len(x_datasets) == 1 and len(y_datasets) == 1:
                action = builder.append_action("Set both from crosshair")

                def set_both():
                    self._set_dataset_from_crosshair(x_datasets[0], "x")
                    self._set_dataset_from_crosshair(y_datasets[0], "y")

                action.triggered.connect(set_both)
        builder.ensure_separator()

        self.channel_menu_group = QtWidgets.QActionGroup(self)
        # Channel entries dereference self.plot, which is None until (successful)
        # initialisation.
        channel_names = self.data_names if self.plot is not None else []
        for name in channel_names:
            action = builder.append_action(name)
            action.setCheckable(True)
            action.setActionGroup(self.channel_menu_group)
            action.setChecked(name == self.plot.active_channel_name)
            action.triggered.connect(
                lambda *a, name=name: self.plot.activate_channel(name))
        builder.ensure_separator()

        super().build_context_menu(builder)

    def _set_dataset_from_crosshair(self, dataset, axis):
        """Set ``dataset`` to the crosshair's last x or y coordinate."""
        if not self.crosshair:
            logger.warning(
                "Plot not initialised yet, ignoring set dataset request")
            return
        self.model.context.set_dataset(
            dataset,
            self.crosshair.last_x if axis == "x" else self.crosshair.last_y)
Пример #14
0
class Rolling1DPlotWidget(AlternateMenuPlotWidget):
    """Scatter plot of the scalar result channels of a :class:`SinglePointModel`,
    showing a rolling history of the most recent points.

    Channels are grouped onto separate y axes as returned by
    ``group_channels_into_axes``. The history length is passed to every series and
    can be changed from the context menu while this is the online master.
    """
    # Emitted with a message if the channel schemata cannot be displayed.
    error = QtCore.pyqtSignal(str)
    # Emitted once the series have been (re)created from the channel schemata.
    ready = QtCore.pyqtSignal()

    def __init__(self, model: SinglePointModel, get_alternate_plot_names):
        super().__init__(get_alternate_plot_names)

        self.model = model
        self.model.channel_schemata_changed.connect(self._initialise_series)
        self.model.point_changed.connect(self._append_point)

        self.series = []
        # Number of points each series keeps; user-adjustable via the context menu.
        self._history_length = 1024

        self.showGrid(x=True, y=True)

        self.source_label = add_source_id_label(
            self.getPlotItem().getViewBox(), self.model.context)

    def _initialise_series(self):
        """(Re)create one series per scalar channel from the model's channel schemata.

        Existing series are removed first. Emits ``error`` and gives up if the
        channels cannot be split into scalar data/error bar channels; otherwise emits
        ``ready`` once done.
        """
        for s in self.series:
            s.remove_items()
        self.series.clear()

        channels = self.model.get_channel_schemata()
        try:
            data_names, error_bar_names = extract_scalar_channels(channels)
        except ValueError as e:
            self.error.emit(str(e))
            return

        # Running index over all series (across axes), used to cycle colours.
        series_idx = 0
        axes = group_channels_into_axes(channels, data_names)
        for names in axes:
            axis, view_box = self.new_y_axis()

            # (label, identity, colour, schema) tuples describing the channels on
            # this axis, consumed by setup_axis_item().
            info = []
            for name in names:
                color = SERIES_COLORS[series_idx % len(SERIES_COLORS)]
                data_item = pyqtgraph.ScatterPlotItem(pen=None,
                                                      brush=color,
                                                      size=6)

                error_bar_item = None
                error_bar_name = error_bar_names.get(name, None)
                if error_bar_name:
                    error_bar_item = pyqtgraph.ErrorBarItem(pen=color)

                self.series.append(
                    _Series(view_box, name, data_item, error_bar_name,
                            error_bar_item, self._history_length))

                channel = channels[name]
                # Fall back to the last path component if there is no description.
                label = channel["description"]
                if not label:
                    label = channel["path"].split("/")[-1]
                info.append((label, channel["path"], color, channel))

                series_idx += 1

            setup_axis_item(axis, info)

        self.ready.emit()

    def _append_point(self, point):
        """Append a newly received point to every series."""
        for s in self.series:
            s.append(point)

    def set_history_length(self, n):
        """Set the number of points kept per series (also applied to future series)."""
        self._history_length = n
        for s in self.series:
            s.set_history_length(n)

    def build_context_menu(self, builder):
        """Add a history length ("N") spin box when online, then the superclass
        entries."""
        if self.model.context.is_online_master():
            # If no new data points are coming in, setting the history size wouldn't do
            # anything.
            # TODO: is_online_master() should really be something like
            # SinglePointModel.ever_updates().

            num_history_box = QtWidgets.QSpinBox()
            num_history_box.setMinimum(1)
            num_history_box.setMaximum(2**16)
            num_history_box.setValue(self._history_length)
            num_history_box.valueChanged.connect(self.set_history_length)

            container = QtWidgets.QWidget()

            layout = QtWidgets.QHBoxLayout()
            container.setLayout(layout)

            label = QtWidgets.QLabel("N: ")
            layout.addWidget(label)

            layout.addWidget(num_history_box)

            # Embed the label + spin box row directly in the menu.
            action = builder.append_widget_action()
            action.setDefaultWidget(container)
        builder.ensure_separator()
        super().build_context_menu(builder)
Пример #15
0
class ScanModel(Model):
    """Base class for models of scan data, i.e. points along one or more axes.

    Concrete subclasses provide the point data (:meth:`get_point_data`) and call
    :meth:`_set_online_analyses`/:meth:`_set_annotation_schemata` once/whenever the
    corresponding metadata has been received.

    :param axes: Schemata of the scan axes.
    :param schema_revision: Schema revision of the source data; for revisions < 2,
        ``analysis_result`` annotation data sources are interpreted as online results.
    :param context: Passed on to :class:`Model`.
    """
    points_rewritten = QtCore.pyqtSignal(dict)
    points_appended = QtCore.pyqtSignal(dict)
    # Emitted with the newly built list of annotations.
    annotations_changed = QtCore.pyqtSignal(list)

    def __init__(self, axes: List[Dict[str, Any]], schema_revision: int,
                 context: Context):
        super().__init__(schema_revision, context)
        self.axes = axes
        self._annotations = []
        self._annotation_schemata = []
        self._online_analyses = {}

    def get_point_data(self) -> Dict[str, Any]:
        """Return the point data received so far (implemented by subclasses)."""
        raise NotImplementedError

    def get_annotations(self) -> List[Annotation]:
        """Return the annotations built by the last :meth:`_set_annotation_schemata`."""
        return self._annotations

    def get_analysis_result_source(
            self, name: str) -> Optional[AnnotationDataSource]:
        """Return a data source for the named analysis result, or ``None`` if not
        available (implemented by subclasses)."""
        raise NotImplementedError

    #
    # TODO: Having these as elaborate implementation in the base class leaves a bit of a
    # bad aftertaste, although it's slightly hard to qualify why it should be bad
    # design.
    #

    def _set_annotation_schemata(self, schemata: List[Dict[str, Any]]):
        """Replace annotations with ones created according to the given schemata.

        This will be called by concrete subclasses once/whenever they have received the
        annotation metadata. Annotations for which any data source cannot be resolved
        are skipped (with a warning). Emits ``annotations_changed`` afterwards.
        """
        self._annotation_schemata = schemata
        self._annotations = []

        def data_source(spec):
            kind = spec["kind"]
            if kind == "fixed":
                return FixedDataSource(spec["value"])

            # `online_result` was called `analysis_result` prior to revision 2, with
            # identical semantics; analysis results proper didn't exist.
            if kind == "online_result" or (self.schema_revision < 2
                                           and kind == "analysis_result"):
                analysis = self._online_analyses.get(spec["analysis_name"],
                                                     None)
                if analysis is None:
                    return None
                return OnlineAnalysisDataSource(analysis, spec["result_key"])
            if kind == "analysis_result":
                name = spec["name"]
                source = self.get_analysis_result_source(name)
                if source is None:
                    logger.info("Analysis result data source not found: %s",
                                name)
                return source

            logger.info(
                "Ignoring unsupported annotation data source type: '%s'", kind)
            return None

        def to_data_sources(specs):
            # A schema without (or with a null) "coordinates"/"data" entry simply
            # has no sources of that type.
            return {k: data_source(v) for k, v in (specs or {}).items()}

        for schema in schemata:
            sources = [
                to_data_sources(schema.get(n)) for n in ("coordinates", "data")
            ]
            if any(s is None for t in sources for s in t.values()):
                logger.warning("Ignoring annotation, not all data found: %s",
                               schema)
                continue
            self._annotations.append(
                Annotation(schema["kind"], schema.get("parameters", {}),
                           *sources))
        self.annotations_changed.emit(self._annotations)

    def _set_online_analyses(
            self, analysis_schemata: Dict[str, Dict[str, Any]]) -> None:
        """Create and hook up online analyses from the given schema.

        This will be called by concrete subclasses once/whenever they have received
        the schema metadata. Previously created analyses are stopped first, and the
        annotations are rebuilt so they refer to the new analyses.
        """
        for a in self._online_analyses.values():
            a.stop()
        self._online_analyses = {}

        for name, schema in analysis_schemata.items():
            kind = schema["kind"]
            if kind == "named_fit":
                self._online_analyses[name] = OnlineNamedFitAnalysis(
                    schema, self)
            else:
                logger.warning(
                    "Ignoring unsupported online analysis type: '%s'", kind)

        # Rebind annotation schemata to new analysis data sources.
        self._set_annotation_schemata(self._annotation_schemata)
Пример #16
0
class XY1DPlotWidget(pyqtgraph.PlotWidget):
    """Scatter plot of scalar result channels against a single scan axis, with
    optional fit curves/points of interest from the ``auto_fit`` specs.

    Series are created lazily the first time :meth:`data_changed` sees the
    ``ndscan.channels`` metadata; afterwards every call redraws the series from the
    ``ndscan.points.*`` data.

    :param x_schema: Schema of the scan axis (``param`` and ``path`` entries).
    :param set_dataset: Callable ``(key, value)`` used by the context menu to set
        datasets from the crosshair position.
    """
    error = QtCore.pyqtSignal(str)

    def __init__(self, x_schema, set_dataset):
        super().__init__()

        self.set_dataset = set_dataset

        self.series_initialised = False
        self.series = []

        path = x_schema["path"]
        if not path:
            path = "/"
        identity_string = x_schema["param"]["fqn"] + "@" + path
        self.x_unit_suffix, self.x_data_to_display_scale = setup_axis_item(
            self.getAxis("bottom"), x_schema["param"]["description"],
            identity_string, x_schema["param"]["spec"])

        self._install_context_menu(x_schema)
        self.crosshair = None
        self.showGrid(x=True, y=True)

    def data_changed(self, data, mods):
        """Update the plot from the dataset mapping ``data``.

        Values are looked up under ``"ndscan." + name`` and taken from the second
        element of the stored tuple. Emits ``error`` (and leaves the plot empty) if
        the channels cannot be displayed.
        """
        def d(name):
            return data.get("ndscan." + name, (False, None))[1]

        if not self.series_initialised:
            channels_json = d("channels")
            if not channels_json:
                return

            channels = json.loads(channels_json)

            try:
                data_names, error_bar_names = extract_scalar_channels(channels)
            except ValueError as e:
                self.error.emit(str(e))
                # Without scalar channels there is nothing to set up (and
                # data_names would be unbound below).
                return

            # KLUDGE: We rely on fit specs to be set before channels in order
            # for them to be displayed at all.
            fit_specs = json.loads(d("auto_fit") or "[]")

            for i, name in enumerate(data_names):
                color = SERIES_COLORS[i % len(SERIES_COLORS)]
                data_item = pyqtgraph.ScatterPlotItem(pen=None,
                                                      brush=color,
                                                      size=5)

                error_bar_name = error_bar_names.get(name, None)
                error_bar_item = pyqtgraph.ErrorBarItem(
                    pen=color) if error_bar_name else None

                # TODO: Multiple fit specs, error bars from other channels.
                fit_spec = None
                fit_item = None
                fit_pois = []
                for spec in fit_specs:
                    if spec["data"]["x"] != "axis_0":
                        continue
                    if spec["data"]["y"] != "channel_" + name:
                        continue
                    e = spec["data"].get("y_err", None)
                    if e and (error_bar_name is None
                              or e != "channel_" + error_bar_name):
                        # Fit expects error bars from a channel this series does not
                        # have (previously a TypeError if it had none at all).
                        continue

                    fit_spec = spec
                    fit_color = FIT_COLORS[i % len(FIT_COLORS)]
                    pen = pyqtgraph.mkPen(fit_color, width=3)
                    fit_item = pyqtgraph.PlotCurveItem(pen=pen)

                    for p in spec.get("pois", []):
                        # TODO: Support horizontal lines, points, ...
                        if p.get("x", None):
                            fit_pois.append(
                                _VLineFitPOI(p["x"], fit_color,
                                             self.x_data_to_display_scale,
                                             self.x_unit_suffix))
                    break

                self.series.append(
                    _XYSeries(self, name, data_item, error_bar_name,
                              error_bar_item, False, fit_spec, fit_item,
                              fit_pois))

            if len(data_names) == 1:
                # If there is only one series, set label/scaling accordingly.
                # TODO: Add multiple y axis for additional channels.
                c = channels[data_names[0]]

                label = c["description"]
                if not label:
                    label = c["path"].split("/")[-1]

                # TODO: Change result channel schema and move properties accessed here
                # into "spec" field to match parameters?
                self.y_unit_suffix, self.y_data_to_display_scale = setup_axis_item(
                    self.getAxis("left"), label, c["path"], c)
            else:
                self.y_unit_suffix = ""
                self.y_data_to_display_scale = 1.0

            self.crosshair = LabeledCrosshairCursor(
                self, self, self.x_unit_suffix, self.x_data_to_display_scale,
                self.y_unit_suffix, self.y_data_to_display_scale)
            self.series_initialised = True

        x_data = d("points.axis_0")
        if not x_data:
            return

        for s in self.series:
            s.update(x_data, data)

    def _install_context_menu(self, x_schema):
        """Add "Set '<dataset>' from crosshair" entries for datasets linked to the
        x parameter to the plot's context menu."""
        entries = []

        for d in extract_linked_datasets(x_schema["param"]):
            action = QtWidgets.QAction("Set '{}' from crosshair".format(d),
                                       self)
            # Bind d as a default argument; a plain closure would make every action
            # set the last dataset of the loop. *a absorbs the signal's arguments.
            action.triggered.connect(
                lambda *a, d=d: self._set_dataset_from_crosshair_x(d))
            entries.append(action)

        if entries:
            separator = QtWidgets.QAction("", self)
            separator.setSeparator(True)
            entries.append(separator)

        self.plotItem.getContextMenus = lambda ev: entries

    def _set_dataset_from_crosshair_x(self, dataset):
        """Set ``dataset`` to the crosshair's last x coordinate."""
        if not self.crosshair:
            logger.warning(
                "Plot not initialised yet, ignoring set dataset request")
            return
        self.set_dataset(dataset, self.crosshair.last_x)
Пример #17
0
class Root(QtCore.QObject):
    """Abstract holder of a top-level model; subclasses implement :meth:`get_model`.
    """

    # Carries a model object; by its name, emitted when the root's model changes.
    model_changed = QtCore.pyqtSignal(object)

    def get_model(self) -> Union["Model", None]:
        """Return the current model, or ``None`` if there is none (abstract)."""
        raise NotImplementedError
Пример #18
0
class XY1DPlotWidget(SubplotMenuPlotWidget):
    """Scatter plot of the scalar result channels of a :class:`ScanModel` against
    its first axis, with annotations (locations, curves, computed curves) and
    click-to-select of points for subscans.
    """
    error = QtCore.pyqtSignal(str)
    ready = QtCore.pyqtSignal()

    def __init__(self, model: ScanModel, get_alternate_plot_names):
        super().__init__(model.context, get_alternate_plot_names)

        self.model = model
        self.model.channel_schemata_changed.connect(self._initialise_series)
        self.model.points_appended.connect(self._update_points)
        self.model.annotations_changed.connect(self._update_annotations)

        # FIXME: Just re-set values instead of throwing away everything.
        def rewritten(points):
            self._initialise_series(self.model.get_channel_schemata())
            self._update_points(points)

        self.model.points_rewritten.connect(rewritten)

        self.selected_point_model = SelectPointFromScanModel(self.model)
        self.subscan_roots = {}

        self.annotation_items = []
        self.series = []

        x_schema = self.model.axes[0]
        path = x_schema["path"]
        if not path:
            path = "/"
        identity_string = x_schema["param"]["fqn"] + "@" + path
        self.x_unit_suffix, self.x_data_to_display_scale = setup_axis_item(
            self.getAxis("bottom"), [(x_schema["param"]["description"], identity_string,
                                      None, x_schema["param"]["spec"])])
        self.crosshair = None
        self._highlighted_spot = None
        self.showGrid(x=True, y=True)

        self.getPlotItem().getViewBox().scene().sigMouseClicked.connect(
            self._handle_scene_click)

    def _initialise_series(self, channels):
        """(Re)create one scatter series per scalar channel in ``channels``.

        Emits ``error`` and gives up if the channels cannot be split into scalar
        data/error bar channels; otherwise emits ``ready`` once done.
        """
        for s in self.series:
            s.remove_items()
        self.series.clear()

        try:
            data_names, error_bar_names = extract_scalar_channels(channels)
        except ValueError as e:
            self.error.emit(str(e))
            return

        colors = [SERIES_COLORS[i % len(SERIES_COLORS)] for i in range(len(data_names))]
        for name, color in zip(data_names, colors):
            data_item = pyqtgraph.ScatterPlotItem(pen=None, brush=color, size=6)
            data_item.sigClicked.connect(self._point_clicked)

            error_bar_name = error_bar_names.get(name, None)
            error_bar_item = pyqtgraph.ErrorBarItem(
                pen=color) if error_bar_name else None

            self.series.append(
                _XYSeries(self, name, data_item, error_bar_name, error_bar_item, False))

        # If there is only one series, set unit/scale accordingly.
        # TODO: Add multiple y axes for additional channels.
        def axis_info(i):
            c = channels[data_names[i]]
            label = c["description"]
            if not label:
                label = c["path"].split("/")[-1]
            return label, c["path"], colors[i], c

        self.y_unit_suffix, self.y_data_to_display_scale = setup_axis_item(
            self.getAxis("left"), [axis_info(i) for i in range(len(data_names))])

        if self.crosshair is None:
            # FIXME: Reinitialise crosshair as necessary on schema changes.
            self.crosshair = LabeledCrosshairCursor(
                self, self.getPlotItem(), self.x_unit_suffix,
                self.x_data_to_display_scale, self.y_unit_suffix,
                self.y_data_to_display_scale)
        self.subscan_roots = create_subscan_roots(self.selected_point_model)
        self.ready.emit()

    def _update_points(self, points):
        """Redraw all series from ``points`` (ignored while there are no x values)."""
        x_data = points["axis_0"]
        # Compare length to zero instead of using `not x_data` for NumPy array
        # compatibility.
        if len(x_data) == 0:
            return

        for s in self.series:
            s.update(x_data, points)

    def _update_annotations(self):
        """Replace all annotation items with ones for the model's current annotations.

        Annotations of unsupported kind/coordinates are skipped (logged at info level).
        """
        for item in self.annotation_items:
            item.remove()
        self.annotation_items.clear()

        def series_idx(ref):
            # Index of the series plotting channel `ref`, falling back to the first.
            for i, s in enumerate(self.series):
                if "channel_" + s.data_name == ref:
                    return i
            return 0

        def associated_series_idx(annotation):
            # Highest series index among the annotation's associated channels; 0 if
            # there are none (max() of an empty sequence would otherwise raise).
            channels = annotation.parameters.get("associated_channels", [])
            return max((series_idx(chan) for chan in channels), default=0)

        def make_curve_item(series_idx):
            color = FIT_COLORS[series_idx % len(FIT_COLORS)]
            pen = pyqtgraph.mkPen(color, width=3)
            return pyqtgraph.PlotCurveItem(pen=pen)

        annotations = self.model.get_annotations()
        for a in annotations:
            if a.kind == "location":
                if set(a.coordinates.keys()) == {"axis_0"}:
                    idx = associated_series_idx(a)
                    color = FIT_COLORS[idx % len(FIT_COLORS)]
                    line = VLineItem(a.coordinates["axis_0"],
                                     a.data.get("axis_0_error",
                                                None), self.getPlotItem(), color,
                                     self.x_data_to_display_scale, self.x_unit_suffix)
                    self.annotation_items.append(line)
                    continue

            if a.kind == "curve":
                idx = None
                for i, s in enumerate(self.series):
                    match_coords = {"axis_0", "channel_" + s.data_name}
                    if set(a.coordinates.keys()) == match_coords:
                        idx = i
                        break
                if idx is not None:
                    curve = make_curve_item(idx)
                    y_key = "channel_" + self.series[idx].data_name
                    item = CurveItem(a.coordinates["axis_0"], a.coordinates[y_key],
                                     self.getPlotItem(), curve)
                    self.annotation_items.append(item)
                    continue

            if a.kind == "computed_curve":
                function_name = a.parameters.get("function_name", None)
                if ComputedCurveItem.is_function_supported(function_name):
                    idx = associated_series_idx(a)

                    curve = make_curve_item(idx)
                    item = ComputedCurveItem(function_name, a.data, self.getPlotItem(),
                                             curve)
                    self.annotation_items.append(item)
                    continue

            logger.info("Ignoring annotation of kind '%s' with coordinates %s", a.kind,
                        list(a.coordinates.keys()))

    def build_context_menu(self, builder):
        """Add "set dataset from crosshair" actions when online, then the superclass
        entries."""
        x_schema = self.model.axes[0]

        if self.model.context.is_online_master():
            for d in extract_linked_datasets(x_schema["param"]):
                action = builder.append_action("Set '{}' from crosshair".format(d))
                # Bind d as a default argument; a plain closure would make every
                # action set the last dataset. *a absorbs the signal's arguments.
                action.triggered.connect(
                    lambda *a, d=d: self._set_dataset_from_crosshair_x(d))

        builder.ensure_separator()
        super().build_context_menu(builder)

    def _set_dataset_from_crosshair_x(self, dataset_key):
        """Set ``dataset_key`` to the crosshair's last x coordinate."""
        if not self.crosshair:
            logger.warning("Plot not initialised yet, ignoring set dataset request")
            return
        self.model.context.set_dataset(dataset_key, self.crosshair.last_x)

    def _highlight_spot(self, spot):
        """Highlight ``spot`` (or nothing, if ``None``), un-highlighting the previous
        one."""
        if self._highlighted_spot is not None:
            self._highlighted_spot.resetPen()
            self._highlighted_spot = None
        if spot is not None:
            spot.setPen("y", width=2)
            self._highlighted_spot = spot

    def _point_clicked(self, scatter_plot_item, spot_items):
        """Select the clicked point for the subscan point model."""
        if not spot_items:
            # No points clicked – events don't seem to be emitted in this case anyway.
            self._background_clicked()
            return

        # Arbitrarily choose the first element in the list if multiple spots
        # overlap; the user can always zoom in if that is undesired.
        spot = spot_items[0]
        self._highlight_spot(spot)
        self.selected_point_model.set_source_index(spot.index())

    def _background_clicked(self):
        """Clear the highlight and point selection."""
        self._highlight_spot(None)
        self.selected_point_model.set_source_index(None)

    def _handle_scene_click(self, event):
        """Treat scene clicks not accepted by a point item as background clicks."""
        if not event.isAccepted():
            # Event not handled yet, so background/… was clicked instead of a point.
            self._background_clicked()
Пример #19
0
class OnlineAnalysis(QtCore.QObject):
    """Base class for analyses computed from scan data on the fly.

    Subclasses emit :attr:`updated` when new results are available, and override
    :meth:`stop` to release whatever they hold.
    """

    # Signals that the analysis results have changed.
    updated = QtCore.pyqtSignal()

    def stop(self):
        """Stop the analysis; the default implementation does nothing."""
        return None
Пример #20
0
class OnlineNamedFitAnalysis(OnlineAnalysis):
    """Fit of a named function (from ``FIT_OBJECTS``) to a scan model's data,
    recomputed as new points come in.

    :param schema: The online analysis schema. ``fit_type`` selects the fit
        function; ``data`` maps fit data keys (``x``, ``y``, optionally ``y_err``) to
        keys of the model's point data.
    :param parent_model: The model to draw the data from.
    """
    _trigger_recompute_fit = QtCore.pyqtSignal()

    def __init__(self, schema: Dict[str, Any], parent_model):
        super().__init__()
        self._schema = schema
        self._model = parent_model

        self._fit_type = self._schema["fit_type"]
        self._fit_obj = FIT_OBJECTS[self._fit_type]

        self._last_fit_params = None
        self._last_fit_errors = None

        # Latest complete (equal-length) set of source data; set by _update().
        self._source_data = {}

        # Rate-limit fit recomputations triggered by incoming points.
        self._recompute_fit_limiter = SignalProxy(
            self._trigger_recompute_fit,
            slot=lambda: asyncio.ensure_future(self._recompute_fit()),
            rateLimit=30)
        self._recompute_in_progress = False
        # Fits run in a separate process so they don't block the GUI event loop.
        self._fit_executor = ProcessPoolExecutor(max_workers=1)

        self._model.points_rewritten.connect(self._update)
        self._model.points_appended.connect(self._update)

        self._update()

    def stop(self):
        """Disconnect from the model and shut down the fit executor."""
        self._model.points_rewritten.disconnect(self._update)
        self._model.points_appended.disconnect(self._update)
        self._fit_executor.shutdown(wait=False)

    def get_data(self):
        """Return the last fit parameters plus their errors (as ``<name>_error``).

        Returns an empty dict if no fit has completed yet.

        :raises ValueError: If an error key would overwrite a fit parameter.
        """
        if self._last_fit_params is None:
            return {}
        result = self._last_fit_params.copy()
        for key, value in self._last_fit_errors.items():
            error_key = key + "_error"
            if error_key in result:
                raise ValueError(
                    "Fit error key name collides with result: '{}'".format(
                        error_key))
            result[error_key] = value
        return result

    def _update(self):
        """Fetch the model's point data and schedule a fit if there is enough."""
        data = self._model.get_point_data()

        source_data = {
            param_key: data.get(source_key, [])
            for param_key, source_key in self._schema["data"].items()
        }

        # Truncate the source data to a complete set of points.
        num_points = min((len(v) for v in source_data.values()), default=0)
        if num_points < len(self._fit_obj.parameter_names):
            # Not enough points yet for the number of fit parameters. Keep the
            # previous (consistent) source data for any fit still pending.
            return

        self._source_data = {
            key: value[:num_points]
            for key, value in source_data.items()
        }
        self._trigger_recompute_fit.emit()

    async def _recompute_fit(self):
        """Run the fit on the latest source data and emit ``updated`` on success."""
        if self._recompute_in_progress:
            # Run at most one fit computation at a time. To make sure we don't
            # leave a few final data points completely disregarded, just
            # re-emit the signal – even for long fits, repeated checks aren't
            # expensive, as long as the SignalProxy rate is slow enough.
            self._trigger_recompute_fit.emit()
            return

        self._recompute_in_progress = True
        try:
            # oitg.fitting currently only supports 1D fits, but this could/should be
            # changed.
            xs = self._source_data["x"]
            ys = self._source_data["y"]
            y_errs = self._source_data.get("y_err", None)

            loop = asyncio.get_event_loop()
            fit_params, fit_errors = await loop.run_in_executor(
                self._fit_executor, _run_fit, self._fit_type, xs, ys, y_errs)
        finally:
            # Always clear the flag; otherwise a single failed fit would block all
            # later fits and keep re-triggering this coroutine forever.
            self._recompute_in_progress = False

        self._last_fit_params, self._last_fit_errors = fit_params, fit_errors
        self.updated.emit()
Пример #21
0
class Rolling1DPlotWidget(AlternateMenuPlotWidget):
    """Scatter plot of the scalar result channels of a :class:`SinglePointModel`,
    showing a rolling history of the most recent points.

    The history length is passed to every series and can be changed from the
    context menu while this is the online master.
    """
    error = QtCore.pyqtSignal(str)
    ready = QtCore.pyqtSignal()
    alternate_plot_requested = QtCore.pyqtSignal(str)

    def __init__(self, model: SinglePointModel, get_alternate_plot_names):
        super().__init__(get_alternate_plot_names)

        self.model = model
        self.model.channel_schemata_changed.connect(self._initialise_series)
        self.model.point_changed.connect(self._append_point)

        self.series = []
        # Number of points each series keeps; user-adjustable via the context menu.
        self._history_length = 1024

        self.showGrid(x=True, y=True)

    def _initialise_series(self):
        """(Re)create one scatter series per scalar channel of the model.

        Emits ``error`` and gives up if the channels cannot be split into scalar
        data/error bar channels; otherwise emits ``ready`` once done.
        """
        for s in self.series:
            s.remove_items()
        self.series.clear()

        channels = self.model.get_channel_schemata()
        try:
            data_names, error_bar_names = extract_scalar_channels(channels)
        except ValueError as e:
            self.error.emit(str(e))
            return

        colors = [
            SERIES_COLORS[i % len(SERIES_COLORS)]
            for i in range(len(data_names))
        ]
        for data_name, color in zip(data_names, colors):
            data_item = pyqtgraph.ScatterPlotItem(pen=None, brush=color)

            error_bar_name = error_bar_names.get(data_name, None)
            error_bar_item = pyqtgraph.ErrorBarItem(
                pen=color) if error_bar_name else None

            self.series.append(
                _Series(self, data_name, data_item, error_bar_name,
                        error_bar_item, self._history_length))

        def axis_info(i):
            # (label, identity, colour, schema) for the i-th series, falling back to
            # the last path component if the channel has no description.
            # TODO: Add separate y axes for channels with different units.
            c = channels[data_names[i]]
            label = c["description"]
            if not label:
                label = c["path"].split("/")[-1]
            return label, c["path"], colors[i], c

        # All series share the left axis.
        setup_axis_item(self.getAxis("left"),
                        [axis_info(i) for i in range(len(data_names))])

        self.ready.emit()

    def _append_point(self, point):
        """Append a newly received point to every series."""
        for s in self.series:
            s.append(point)

    def set_history_length(self, n):
        """Set the number of points kept per series (also applied to future series)."""
        self._history_length = n
        for s in self.series:
            s.set_history_length(n)

    def build_context_menu(self, builder):
        """Add a history length ("N") spin box when online, then the superclass
        entries."""
        if self.model.context.is_online_master():
            # If no new data points are coming in, setting the history size wouldn't do
            # anything.
            # TODO: is_online_master() should really be something like
            # SinglePointModel.ever_updates().

            num_history_box = QtWidgets.QSpinBox()
            num_history_box.setMinimum(1)
            num_history_box.setMaximum(2**16)
            num_history_box.setValue(self._history_length)
            num_history_box.valueChanged.connect(self.set_history_length)

            container = QtWidgets.QWidget()

            layout = QtWidgets.QHBoxLayout()
            container.setLayout(layout)

            label = QtWidgets.QLabel("N: ")
            layout.addWidget(label)

            layout.addWidget(num_history_box)

            # Embed the label + spin box row directly in the menu.
            action = builder.append_widget_action()
            action.setDefaultWidget(container)
        builder.ensure_separator()
        super().build_context_menu(builder)
Пример #22
0
class OnlineNamedFitAnalysis(OnlineAnalysis):
    """Implements :class:`ndscan.experiment.default_analysis.OnlineFit`, that is, a fit
    of a well-known function that is executed repeatedly as new data is coming in.

    :param schema: The ``ndscan.online_analyses`` schema to implement.
    :param parent_model: The :class:`~ndscan.plots.model.ScanModel` to draw the data
        from. The schema is expected not to change until :meth:`stop` is called.
    """
    _trigger_recompute_fit = QtCore.pyqtSignal()

    def __init__(self, schema: Dict[str, Any], parent_model):
        super().__init__()
        self._schema = schema
        self._model = parent_model

        self._fit_type = self._schema["fit_type"]
        self._fit_obj = FIT_OBJECTS[self._fit_type]
        self._constants = self._schema.get("constants", {})
        self._initial_values = self._schema.get("initial_values", {})

        self._last_fit_params = None
        self._last_fit_errors = None

        # Latest complete (equal-length) set of source data; set by _update().
        self._source_data = {}

        # Rate-limit fit recomputations triggered by incoming points.
        self._recompute_fit_limiter = SignalProxy(
            self._trigger_recompute_fit,
            slot=lambda: asyncio.ensure_future(self._recompute_fit()),
            rateLimit=30)
        self._recompute_in_progress = False
        # Fits run in a separate process so they don't block the GUI event loop.
        self._fit_executor = ProcessPoolExecutor(max_workers=1)

        self._model.points_rewritten.connect(self._update)
        self._model.points_appended.connect(self._update)

        self._update()

    def stop(self):
        """Disconnect from the model and shut down the fit executor."""
        self._model.points_rewritten.disconnect(self._update)
        self._model.points_appended.disconnect(self._update)
        self._fit_executor.shutdown(wait=False)

    def get_data(self):
        """Return the last fit parameters plus their errors (as ``<name>_error``).

        Returns an empty dict if no fit has completed yet.

        :raises ValueError: If an error key would overwrite a fit parameter.
        """
        if self._last_fit_params is None:
            return {}
        result = self._last_fit_params.copy()
        for key, value in self._last_fit_errors.items():
            error_key = key + "_error"
            if error_key in result:
                raise ValueError(
                    "Fit error key name collides with result: '{}'".format(
                        error_key))
            result[error_key] = value
        return result

    def _update(self):
        """Fetch the model's point data and schedule a fit if there is enough."""
        data = self._model.get_point_data()

        source_data = {
            param_key: data.get(source_key, [])
            for param_key, source_key in self._schema["data"].items()
        }

        # Truncate the source data to a complete set of points.
        num_points = min((len(v) for v in source_data.values()), default=0)
        if num_points < len(self._fit_obj.parameter_names):
            # Not enough points yet for the given number of degrees of freedom. Keep
            # the previous (consistent) source data for any fit still pending.
            return

        self._source_data = {
            key: value[:num_points]
            for key, value in source_data.items()
        }
        self._trigger_recompute_fit.emit()

    async def _recompute_fit(self):
        """Run the fit on the latest source data and emit ``updated`` on success."""
        if self._recompute_in_progress:
            # Run at most one fit computation at a time. To make sure we don't
            # leave a few final data points completely disregarded, just
            # re-emit the signal – even for long fits, repeated checks aren't
            # expensive, as long as the SignalProxy rate is slow enough.
            self._trigger_recompute_fit.emit()
            return

        self._recompute_in_progress = True
        try:
            # oitg.fitting currently only supports 1D fits, but this could/should be
            # changed.
            xs = self._source_data["x"]
            ys = self._source_data["y"]
            y_errs = self._source_data.get("y_err", None)

            loop = asyncio.get_event_loop()
            fit_params, fit_errors = await loop.run_in_executor(
                self._fit_executor, _run_fit, self._fit_type, xs, ys, y_errs,
                self._constants, self._initial_values)
        finally:
            # Always clear the flag; otherwise a single failed fit would block all
            # later fits and keep re-triggering this coroutine forever.
            self._recompute_in_progress = False

        self._last_fit_params, self._last_fit_errors = fit_params, fit_errors
        self.updated.emit()