Example #1
class ReferenceWidget:
    "Dropdown for choosing the reference"
    __files: FileList
    __theme: ReferenceWidgetTheme
    __widget: Dropdown

    def __init__(self, ctrl, model) -> None:
        self.__theme = ctrl.theme.swapmodels(ReferenceWidgetTheme())
        self.__model = model
        self.__files = FileList(ctrl)

    def addtodoc(self, mainview, ctrl, *_) -> List[Widget]:
        "creates the widget"

        self.__widget = Dropdown(name='HS:reference',
                                 width=self.__theme.width,
                                 height=self.__theme.height,
                                 **self.__data())

        @mainview.actionifactive(ctrl)
        def _py_cb(new):
            inew = int(new.item)
            val = None if inew < 0 else [i for _, i in self.__files()][inew]
            self.__model.fittoreference.reference = val

        def _observe(old=None, **_):
            if 'reference' in old and mainview.isactive():
                data = self.__data()
                mainview.calllater(lambda: self.__widget.update(**data))

        ctrl.display.observe(FitToReferenceStore().name, _observe)

        self.__widget.on_click(_py_cb)
        return [self.__widget]

    def reset(self, resets):
        "updates the widget"
        resets[self.__widget].update(**self.__data())

    @property
    def widget(self):
        "returns the widget"
        return self.__widget

    def __data(self) -> dict:
        lst = list(self.__files())
        menu: list = [(j, str(i)) for i, j in enumerate(i for i, _ in lst)]
        menu += [None, (self.__theme.title, '-1')]

        key = self.__model.fittoreference.reference
        index = -1 if key is None else [i for _, i in lst].index(key)
        return dict(menu=menu, label=menu[index][0], value=str(index))
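
All of these examples hinge on the same mechanism: Dropdown.on_click registers a Python callback that receives a MenuItemClick event whose .item attribute carries the value of the clicked menu entry. A minimal, self-contained sketch of that mechanism (the widget and handler names are illustrative, not taken from the example above):

from bokeh.io import curdoc
from bokeh.models import Dropdown

dropdown = Dropdown(label="Pick one", menu=[("First", "1"), ("Second", "2")])

def on_pick(event):
    # event.item is the value ("1" or "2") of the clicked menu entry
    dropdown.label = event.item

dropdown.on_click(on_pick)
curdoc().add_root(dropdown)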
Example #2
    menu.append((key, key))



plot = p.circle(x='xdata', y='ydata', source=source, size='size',
                color='ColorList', alpha=0.6, line_color="black", line_width=1)
ds = plot.data_source

dropdownX = Dropdown(label="X value", button_type="default", menu=menu)
def dropdownX_handler(new):
    dictionary['xdata'] = dictionary[new.item]
    ds.data = dictionary
    print(dictionary['xdata'])
    print(new.item)
    p.xaxis.axis_label = new.item

dropdownX.on_click(dropdownX_handler)


dropdownY = Dropdown(label="Y value", button_type="default", menu=menu)
def dropdownY_handler(new):
    dictionary['ydata'] = dictionary[new.item]
    ds.data = dictionary
    p.yaxis.axis_label = new.item

dropdownY.on_click(dropdownY_handler)

dropdownR = Dropdown(label="Radius", button_type="default", menu=menu)
def dropdownR_handler(new):
    temp = dictionary[new.item]
    sizelist = [float(i - min(temp)) / float(max(temp) - min(temp)) * 29 + 5 for i in temp]
    dictionary['size'] = sizelist
    ds.data = dictionary  # push the recomputed sizes, as the X/Y handlers do
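
The radius handler min-max-normalizes the chosen column into glyph sizes in the range [5, 34]. That scaling extracted into a standalone helper (a sketch; the helper name is not from the example):

def scale_to_sizes(values, lo=5.0, span=29.0):
    # min-max normalize values into [lo, lo + span] for use as glyph sizes
    vmin, vmax = min(values), max(values)
    return [(v - vmin) / (vmax - vmin) * span + lo for v in values]

# scale_to_sizes([0, 5, 10]) -> [5.0, 19.5, 34.0]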
Example #3
class SupervisableDataset(Loggable):
    """
    Feature-agnostic class for a dataset open to supervision.
    Raw -- piecewise annotation -> Gold -> Dev/Test
    Raw -- batch annotation -> Noisy -> Train

    Keeping a DataFrame form and a list-of-dicts ("dictl") form, with the intention that
    - the DataFrame form supports most kinds of operations;
    - the list-of-dicts form could be useful for manipulations outside the scope of pandas;
    - synchronization between the two forms should be called sparingly.
    """

    # 'scratch': intended to be directly editable by other objects, e.g. Explorers
    # labels will be stored but not used for information in hover itself
    SCRATCH_SUBSETS = tuple(["raw"])

    # non-'scratch': intended to be read-only outside of the class
    # 'public': labels will be considered as part of the classification task and will be used for built-in supervision
    PUBLIC_SUBSETS = tuple(["train", "dev"])
    # 'private': labels will be considered as part of the classification task and will NOT be used for supervision
    PRIVATE_SUBSETS = tuple(["test"])

    FEATURE_KEY = "feature"

    def __init__(
        self,
        raw_dictl,
        train_dictl=None,
        dev_dictl=None,
        test_dictl=None,
        feature_key="feature",
        label_key="label",
    ):
        """
        Initialize the dataset with dictl and df forms; initialize the mapping between categorical-int and string labels.

        - param raw_dictl: a list of dicts holding the raw data that DO NOT have annotation.
        - param train_dictl: a list of dicts holding the batch-annotated noisy train set.
        - param dev_dictl: a list of dicts holding the gold dev set.
        - param test_dictl: a list of dicts holding the gold test set.
        - param feature_key: key in each piece of dict mapping to the feature.
        - param label_key: key in each piece of dict mapping to the ground truth in STRING form.
        """
        self._info("Initializing...")

        def dictl_transform(dictl, labels=True):
            """
            Burner function to transform the input list of dictionaries into standard format.
            """
            # edge case when dictl is empty or None
            if not dictl:
                return []

            # transform the feature and possibly the label
            key_transform = {feature_key: self.__class__.FEATURE_KEY}
            if labels:
                key_transform[label_key] = "label"

            def burner(d):
                """
                Burner function to transform a single dict.
                """
                if labels:
                    assert label_key in d, f"Expected dict key {label_key}"

                trans_d = {
                    key_transform.get(_k, _k): _v
                    for _k, _v in d.items()
                }

                if not labels:
                    trans_d["label"] = module_config.ABSTAIN_DECODED

                return trans_d

            return [burner(_d) for _d in dictl]

        self.dictls = {
            "raw": dictl_transform(raw_dictl, labels=False),
            "train": dictl_transform(train_dictl),
            "dev": dictl_transform(dev_dictl),
            "test": dictl_transform(test_dictl),
        }

        self.synchronize_dictl_to_df()
        self.df_deduplicate()
        self.synchronize_df_to_dictl()
        self.setup_widgets()
        # self.setup_label_coding() # redundant if setup_pop_table() immediately calls this again
        self.setup_pop_table(width_policy="fit", height_policy="fit")
        self._good("Finished initialization.")

    def copy(self, use_df=True):
        """
        Create another instance, carrying over the data entries.
        """
        if use_df:
            self.synchronize_df_to_dictl()
        return self.__class__(
            raw_dictl=self.dictls["raw"],
            train_dictl=self.dictls["train"],
            dev_dictl=self.dictls["dev"],
            test_dictl=self.dictls["test"],
            feature_key=self.__class__.FEATURE_KEY,
            label_key="label",
        )

    def setup_widgets(self):
        """
        Critical widgets for interactive data management.
        """
        self.update_pusher = Button(label="Push",
                                    button_type="success",
                                    height_policy="fit",
                                    width_policy="min")
        self.data_committer = Dropdown(
            label="Commit",
            button_type="warning",
            menu=[
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
            ],
            height_policy="fit",
            width_policy="min",
        )
        self.dedup_trigger = Button(
            label="Dedup",
            button_type="warning",
            height_policy="fit",
            width_policy="min",
        )

        def commit_base_callback():
            """
            COMMIT creates cross-duplicates between subsets.

            - PUSH shall be blocked until DEDUP is executed.
            """
            self.dedup_trigger.disabled = False
            self.update_pusher.disabled = True

        def dedup_base_callback():
            """
            DEDUP re-creates dfs with different indices than before.

            - COMMIT shall be blocked until PUSH is executed.
            """
            self.update_pusher.disabled = False
            self.data_committer.disabled = True
            self.df_deduplicate()

        def push_base_callback():
            """
            PUSH enforces df consistency with all linked explorers.

            - DEDUP could be blocked because it stays trivial until COMMIT is executed.
            """
            self.data_committer.disabled = False
            self.dedup_trigger.disabled = True
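        # Together the three callbacks form a cycle: COMMIT enables DEDUP and
        # blocks PUSH; DEDUP re-enables PUSH and blocks COMMIT; PUSH re-enables
        # COMMIT and blocks DEDUP, keeping the dfs and explorers consistent.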

        self.update_pusher.on_click(push_base_callback)
        self.data_committer.on_click(commit_base_callback)
        self.dedup_trigger.on_click(dedup_base_callback)

        self.help_div = dataset_help_widget()

    def view(self):
        """
        Defines the layout of bokeh models.
        """
        # local import to avoid naming confusion/conflicts
        from bokeh.layouts import row, column

        return column(
            self.help_div,
            row(self.update_pusher, self.data_committer, self.dedup_trigger),
            self.pop_table,
        )

    def subscribe_update_push(self, explorer, subset_mapping):
        """
        Enable pushing updated DataFrames to explorers that depend on them.

        Note: the reason we need this is due to `self.dfs[key] = ...`-like assignments. If DF operations were all in-place, then the explorers could directly access the updates through their `self.dfs` references.
        """
        # local import to avoid import cycles
        from hover.core.explorer import BokehForLabeledText

        assert isinstance(explorer, BokehForLabeledText)

        def callback_push():
            df_dict = {_v: self.dfs[_k] for _k, _v in subset_mapping.items()}
            explorer._setup_dfs(df_dict)
            explorer._update_sources()

        self.update_pusher.on_click(callback_push)
        self._good(
            f"Subscribed {explorer.__class__.__name__} to dataset pushes: {subset_mapping}"
        )

    def subscribe_data_commit(self, explorer, subset_mapping):
        """
        Enable committing data across subsets, specified by a selection in an explorer and a dropdown widget of the dataset.
        """
        def callback_commit(event):
            for sub_k, sub_v in subset_mapping.items():
                sub_to = event.item
                selected_idx = explorer.sources[sub_v].selected.indices
                if not selected_idx:
                    self._warn(
                        f"Attempting data commit: did not select any data points in subset {sub_v}."
                    )
                    return

                # take selected slice, ignoring ABSTAIN'ed rows
                sel_slice = self.dfs[sub_k].iloc[selected_idx]
                valid_slice = sel_slice[
                    sel_slice["label"] != module_config.ABSTAIN_DECODED]

                # concat to the end and do some accounting
                size_before = self.dfs[sub_to].shape[0]
                self.dfs[sub_to] = pd.concat(
                    [self.dfs[sub_to], valid_slice],
                    axis=0,
                    sort=False,
                    ignore_index=True,
                )
                size_mid = self.dfs[sub_to].shape[0]
                self.dfs[sub_to].drop_duplicates(
                    subset=[self.__class__.FEATURE_KEY],
                    keep="last",
                    inplace=True)
                size_after = self.dfs[sub_to].shape[0]

                self._info(
                    f"Committed {valid_slice.shape[0]} (valid out of {sel_slice.shape[0]} selected) entries from {sub_k} to {sub_to} ({size_before} -> {size_after} with {size_mid-size_after} overwrites)."
                )

        self.data_committer.on_click(callback_commit)
        self._good(
            f"Subscribed {explorer.__class__.__name__} to dataset commits: {subset_mapping}"
        )

    def setup_label_coding(self, verbose=True, debug=False):
        """
        Auto-determine labels in the dataset, then create encoder/decoder in lexical order.
        Add ABSTAIN as a no-label placeholder.
        """
        all_labels = set()
        for _key in [
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
        ]:
            _df = self.dfs[_key]
            _found_labels = set(_df["label"].tolist())
            all_labels = all_labels.union(_found_labels)

        # exclude ABSTAIN from self.classes, but include it in the encoding
        all_labels.discard(module_config.ABSTAIN_DECODED)
        self.classes = sorted(all_labels)
        self.label_encoder = {
            **{_label: _i
               for _i, _label in enumerate(self.classes)},
            module_config.ABSTAIN_DECODED: module_config.ABSTAIN_ENCODED,
        }
        self.label_decoder = {_v: _k for _k, _v in self.label_encoder.items()}

        if verbose:
            self._good(
                f"Set up label encoder/decoder with {len(self.classes)} classes."
            )
        if debug:
            self.validate_labels()

    def validate_labels(self, raise_exception=True):
        """
        Check that every label is in the encoder.
        """
        for _key in [
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
        ]:
            assert "label" in self.dfs[_key].columns
            _mask = self.dfs[_key]["label"].apply(
                lambda x: x in self.label_encoder)
            _invalid_indices = np.where(~_mask)[0].tolist()
            if _invalid_indices:
                self._fail(f"Subset {_key} has invalid labels:")
                self._print(self.dfs[_key].loc[_invalid_indices])
                if raise_exception:
                    raise ValueError("invalid labels")

    def setup_pop_table(self, **kwargs):
        subsets = [
            *self.__class__.SCRATCH_SUBSETS,
            *self.__class__.PUBLIC_SUBSETS,
            *self.__class__.PRIVATE_SUBSETS,
        ]
        pop_source = ColumnDataSource(dict())
        pop_columns = [
            TableColumn(field="label", title="label"),
            *[
                TableColumn(field=f"count_{_subset}", title=_subset)
                for _subset in subsets
            ],
        ]
        self.pop_table = DataTable(source=pop_source,
                                   columns=pop_columns,
                                   **kwargs)

        def update_population():
            """
            Callback function.
            """
            # make sure that the label coding is correct
            self.setup_label_coding()

            # re-compute label population
            eff_labels = [module_config.ABSTAIN_DECODED, *self.classes]
            pop_data = dict(label=eff_labels)
            for _subset in subsets:
                _subpop = self.dfs[_subset]["label"].value_counts()
                pop_data[f"count_{_subset}"] = [
                    _subpop.get(_label, 0) for _label in eff_labels
                ]

            # push results to bokeh data source
            pop_source.data = pop_data

            self._good(
                f"Pop updater: latest population with {len(self.classes)} classes"
            )

        update_population()
        self.data_committer.on_click(update_population)
        self.dedup_trigger.on_click(update_population)

    def df_deduplicate(self):
        """
        Cross-deduplicate data entries by feature between subsets.
        """
        self._info("Deduplicating...")
        # for data entry accounting
        before, after = dict(), dict()

        # deduplicating rule: entries that come LATER are of higher priority
        ordered_subsets = [
            *self.__class__.SCRATCH_SUBSETS,
            *self.__class__.PUBLIC_SUBSETS,
            *self.__class__.PRIVATE_SUBSETS,
        ]

        # keep track of which df has which columns and which rows came from which subset
        columns = dict()
        for _key in ordered_subsets:
            before[_key] = self.dfs[_key].shape[0]
            columns[_key] = self.dfs[_key].columns
            self.dfs[_key]["__subset"] = _key

        # concatenate in order and deduplicate
        overall_df = pd.concat([self.dfs[_key] for _key in ordered_subsets],
                               axis=0,
                               sort=False)
        overall_df.drop_duplicates(subset=[self.__class__.FEATURE_KEY],
                                   keep="last",
                                   inplace=True)
        overall_df.reset_index(drop=True, inplace=True)

        # cut up slices
        for _key in ordered_subsets:
            self.dfs[_key] = overall_df[
                overall_df["__subset"] == _key].reset_index(
                    drop=True, inplace=False)[columns[_key]]
            after[_key] = self.dfs[_key].shape[0]
            self._info(
                f"--subset {_key} rows: {before[_key]} -> {after[_key]}.")

    def synchronize_dictl_to_df(self):
        """
        Re-make dataframes from lists of dictionaries.
        """
        self.dfs = dict()
        for _key, _dictl in self.dictls.items():
            if _dictl:
                _df = pd.DataFrame(_dictl)
                assert self.__class__.FEATURE_KEY in _df.columns
                assert "label" in _df.columns
            else:
                _df = pd.DataFrame(
                    columns=[self.__class__.FEATURE_KEY, "label"])

            self.dfs[_key] = _df

    def synchronize_df_to_dictl(self):
        """
        Re-make lists of dictionaries from dataframes.
        """
        self.dictls = dict()
        for _key, _df in self.dfs.items():
            self.dictls[_key] = _df.to_dict(orient="records")

    def compute_2d_embedding(self, vectorizer, method, **kwargs):
        """
        Get embeddings in the xy-plane and return the reducer.
        """
        from hover.core.representation.reduction import DimensionalityReducer

        # prepare input vectors to manifold learning
        fit_subset = [
            *self.__class__.SCRATCH_SUBSETS, *self.__class__.PUBLIC_SUBSETS
        ]
        trans_subset = [*self.__class__.PRIVATE_SUBSETS]

        assert not set(fit_subset).intersection(
            set(trans_subset)), "Unexpected overlap"

        # compute vectors and keep track of where to slice the array for fitting
        feature_inp = []
        for _key in fit_subset:
            feature_inp += self.dfs[_key][self.__class__.FEATURE_KEY].tolist()
        fit_num = len(feature_inp)
        for _key in trans_subset:
            feature_inp += self.dfs[_key][self.__class__.FEATURE_KEY].tolist()
        trans_arr = np.array([vectorizer(_inp) for _inp in tqdm(feature_inp)])

        # initialize and fit manifold learning reducer using specified subarray
        self._info(
            f"Fit-transforming {method.upper()} on {fit_num} samples...")
        reducer = DimensionalityReducer(trans_arr[:fit_num])
        fit_embedding = reducer.fit_transform(method, **kwargs)

        # compute embedding of the whole dataset
        self._info(
            f"Transforming {method.upper()} on {trans_arr.shape[0]-fit_num} samples..."
        )
        trans_embedding = reducer.transform(trans_arr[fit_num:], method)

        # assign x and y coordinates to dataset
        start_idx = 0
        for _subset, _embedding in [
            (fit_subset, fit_embedding),
            (trans_subset, trans_embedding),
        ]:
            # edge case: embedding is too small
            if _embedding.shape[0] < 1:
                for _key in _subset:
                    assert (self.dfs[_key].shape[0] == 0
                            ), "Expected empty df due to empty embedding"
                continue
            for _key in _subset:
                _length = self.dfs[_key].shape[0]
                self.dfs[_key]["x"] = pd.Series(
                    _embedding[start_idx:(start_idx + _length), 0])
                self.dfs[_key]["y"] = pd.Series(
                    _embedding[start_idx:(start_idx + _length), 1])
                start_idx += _length

        return reducer

    def loader(self, key, vectorizer, batch_size=64, smoothing_coeff=0.0):
        """
        Prepare a Torch Dataloader for training or evaluation.

        - param key(str): the subset of dataset to use.
        - param vectorizer(callable): callable that turns a string into a vector.
        - param batch_size(int): batch size for the data loader.
        - param smoothing_coeff(float): the smoothing coefficient for soft labels.
        """
        # lazy import: missing torch should not break the rest of the class
        from hover.utils.torch_helper import vector_dataloader, one_hot, label_smoothing

        # take the slice that has a meaningful label
        df = self.dfs[key][
            self.dfs[key]["label"] != module_config.ABSTAIN_DECODED]

        # edge case: valid slice is too small
        if df.shape[0] < 1:
            raise ValueError(
                f"Subset {key} has too few samples ({df.shape[0]})")
        batch_size = min(batch_size, df.shape[0])

        labels = df["label"].apply(lambda x: self.label_encoder[x]).tolist()
        features = df[self.__class__.FEATURE_KEY].tolist()
        output_vectors = one_hot(labels, num_classes=len(self.classes))

        self._info(f"Preparing {key} input vectors...")
        input_vectors = [vectorizer(_f) for _f in tqdm(features)]
        if smoothing_coeff > 0.0:
            output_vectors = label_smoothing(output_vectors,
                                             coefficient=smoothing_coeff)
        self._info(f"Preparing {key} data loader...")
        loader = vector_dataloader(input_vectors,
                                   output_vectors,
                                   batch_size=batch_size)
        self._good(
            f"Prepared {key} loader consisting of {len(features)} examples with batch size {batch_size}"
        )
        return loader
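
For orientation: constructing and viewing the dataset only requires lists of dicts that carry a feature string and, for labeled subsets, a string label. A minimal usage sketch under the defaults above (the toy records are made up):

raw = [{"feature": "some unlabeled text"}, {"feature": "more unlabeled text"}]
train = [{"feature": "a labeled example", "label": "positive"}]

dataset = SupervisableDataset(raw_dictl=raw, train_dictl=train)
layout = dataset.view()  # help text, push/commit/dedup buttons, population table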
Example #4
class SequencePathWidget:
    "Dropdown for choosing a fasta file"
    _dialog: FileDialog
    _widget: Dropdown
    _theme: SequencePathTheme
    _model: SequencePlotModelAccess

    def __init__(self, ctrl, **kwa):
        self._theme = ctrl.theme.swapmodels(SequencePathTheme(**kwa))
        self._model = SequencePlotModelAccess()
        self._model.swapmodels(ctrl)

    def addtodoc(self, mainview, ctrl, *_) -> List[Widget]:
        "creates the widget"
        self._widget = Dropdown(name='Cycles:Sequence',
                                width=self._theme.width,
                                height=self._theme.height,
                                **self._data())

        mainview.differedobserver(self._data, self._widget, ctrl.theme,
                                  self._model.sequencemodel.config,
                                  ctrl.display,
                                  self._model.sequencemodel.display)

        self._widget.on_click(ctrl.action(self._onclick))
        return [self._widget]

    def observe(self, ctrl):
        "sets-up config observers"
        self._dialog = FileDialog(ctrl,
                                  storage="sequence",
                                  title=self._theme.dlgtitle,
                                  filetypes='fasta|txt|*')

    def reset(self, resets):
        "updates the widget"
        resets[self._widget].update(**self._data())

    @property
    def widget(self):
        "returns the widget"
        return self._widget

    def callbacks(self, hover: SequenceHoverMixin, tick1: SequenceTicker):
        "sets-up callbacks for the tooltips and grids"
        if hover is not None:
            jsc = CustomJS(
                code=("if(Object.keys(src.data).indexOf(cb_obj.value) > -1)"
                      "{ cb_obj.label     = cb_obj.value;"
                      "  tick1.key        = cb_obj.value;"
                      "  tick2.key        = cb_obj.value;"
                      "  src.data['text'] = src.data[cb_obj.value];"
                      "  src.change.emit(); }"),
                args=dict(tick1=tick1, tick2=tick1.axis, src=hover.source))
            self._widget.js_on_change('value', jsc)
        return self._widget

    def _data(self) -> dict:
        lst = sorted(self._model.sequences(...))
        key = self._model.sequencemodel.currentkey
        val = key if key in lst else None
        label = self._theme.missingkey if val is None else key

        menu: List[Optional[Tuple[str, str]]] = [(i, i) for i in lst]
        menu += [
            None if len(menu) else ('', '→'), (self._theme.missingpath, '←')
        ]

        return dict(menu=menu, label=label, value='→' if val is None else val)

    def _onclick(self, new):
        if new.item == '←':
            path = self._dialog.open()
            self._widget.value = '→'
            if self._model.setnewsequencepath(path):
                if path is not None:
                    raise IOError("Could not find any sequence in the file")
        elif new.item != '→':
            self._model.setnewsequencekey(new.item)
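
The callbacks method above attaches a CustomJS to the dropdown's 'value' change, so the label, ticks, and hover text update in the browser without a server round-trip. The same js_on_change mechanism shown on a Select, which exposes a value property in current Bokeh (the option values are made up):

from bokeh.models import CustomJS, Select

select = Select(value="chr1", options=["chr1", "chr2"])
select.js_on_change(
    "value",
    CustomJS(code="console.log('new key: ' + cb_obj.value)"))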
Example #5
def handler1(event):
    lb = event.item
    y1 = cnthr[int(lb)]
    fig.line(x=x_always,
             y=y1,
             line_width=2,
             legend_label=lb,
             line_color='pink')


def handler2(event):
    lb = event.item
    y2 = cnthr[int(lb)]
    fig.line(x_always, y2, line_width=2, legend_label=lb, line_color='blue')


dd1.on_click(handler1)
dd2.on_click(handler2)
fig = figure(x_range=[
    'Jan', 'Feb', 'Mar', 'Apr', 'May', 'June', 'July', 'Aug', 'Sept', 'Oct',
    'Nov', 'Dec'
],
             y_range=(0, 1000),
             title="NYC 311 Average Response Time 2020",
             x_axis_label='month',
             y_axis_label='Average response time in hour')
fig.line(x_always,
         all_avg,
         line_width=2,
         legend_label='all',
         line_color='grey')
curdoc().add_root(column(dd1, dd2, fig))
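
Note that handler1 and handler2 add a brand-new line renderer to fig on every click, so glyphs accumulate. A variant that updates a single renderer in place, reusing x_always and cnthr from above (handler1_alt is a hypothetical replacement, not part of the original):

from bokeh.models import ColumnDataSource

line_source = ColumnDataSource(data=dict(x=[], y=[]))
fig.line('x', 'y', source=line_source, line_width=2, line_color='pink')

def handler1_alt(event):
    # replace the line's data instead of stacking a new glyph per click
    line_source.data = dict(x=x_always, y=cnthr[int(event.item)])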
Example #6
def create(palm):
    fit_max = 1
    fit_min = 0

    doc = curdoc()

    # THz calibration plot
    scan_plot = Plot(
        title=Title(text="THz calibration"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    scan_plot.toolbar.logo = None
    scan_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

    # ---- axes
    scan_plot.add_layout(LinearAxis(axis_label="Stage delay motor"), place="below")
    scan_plot.add_layout(
        LinearAxis(axis_label="Energy shift, eV", major_label_orientation="vertical"), place="left"
    )

    # ---- grid lines
    scan_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    scan_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- circle cluster glyphs
    scan_circle_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(scan_circle_source, Circle(x="x", y="y", line_alpha=0, fill_alpha=0.5))

    # ---- circle glyphs
    scan_avg_circle_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(
        scan_avg_circle_source, Circle(x="x", y="y", line_color="purple", fill_color="purple")
    )

    # ---- line glyphs
    fit_line_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(fit_line_source, Line(x="x", y="y", line_color="purple"))

    # THz calibration folder path text input
    def path_textinput_callback(_attr, _old_value, _new_value):
        update_load_dropdown_menu()
        path_periodic_update()

    path_textinput = TextInput(
        title="THz calibration path:", value=os.path.expanduser("~"), width=510
    )
    path_textinput.on_change("value", path_textinput_callback)

    # THz calibration eco scans dropdown
    def scans_dropdown_callback(event):
        scans_dropdown.label = event.item

    scans_dropdown = Dropdown(label="ECO scans", button_type="default", menu=[])
    scans_dropdown.on_click(scans_dropdown_callback)

    # ---- eco scans periodic update
    def path_periodic_update():
        new_menu = []
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(".json"):
                    new_menu.append((entry.name, entry.name))
        scans_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Calibrate button
    def calibrate_button_callback():
        palm.calibrate_thz(path=os.path.join(path_textinput.value, scans_dropdown.label))
        fit_max_spinner.value = np.ceil(palm.thz_calib_data.index.values.max())
        fit_min_spinner.value = np.floor(palm.thz_calib_data.index.values.min())
        update_calibration_plot()

    def update_calibration_plot():
        scan_plot.xaxis.axis_label = f"{palm.thz_motor_name}, {palm.thz_motor_unit}"

        scan_circle_source.data.update(
            x=np.repeat(
                palm.thz_calib_data.index, palm.thz_calib_data["peak_shift"].apply(len)
            ).tolist(),
            y=np.concatenate(palm.thz_calib_data["peak_shift"].values).tolist(),
        )

        scan_avg_circle_source.data.update(
            x=palm.thz_calib_data.index.tolist(), y=palm.thz_calib_data["peak_shift_mean"].tolist()
        )

        x = np.linspace(fit_min, fit_max, 100)
        y = palm.thz_slope * x + palm.thz_intersect
        fit_line_source.data.update(x=np.round(x, decimals=5), y=np.round(y, decimals=5))

        calib_const_div.text = f"""
        thz_slope = {palm.thz_slope}
        """

    calibrate_button = Button(label="Calibrate THz", button_type="default", width=250)
    calibrate_button.on_click(calibrate_button_callback)

    # THz fit maximal value spinner
    def fit_max_spinner_callback(_attr, old_value, new_value):
        nonlocal fit_max
        if new_value > fit_min:
            fit_max = new_value
            palm.calibrate_thz(
                path=os.path.join(path_textinput.value, scans_dropdown.label),
                fit_range=(fit_min, fit_max),
            )
            update_calibration_plot()
        else:
            fit_max_spinner.value = old_value

    fit_max_spinner = Spinner(title="Maximal fit value:", value=fit_max, step=0.1)
    fit_max_spinner.on_change("value", fit_max_spinner_callback)

    # THz fit minimal value spinner
    def fit_min_spinner_callback(_attr, old_value, new_value):
        nonlocal fit_min
        if new_value < fit_max:
            fit_min = new_value
            palm.calibrate_thz(
                path=os.path.join(path_textinput.value, scans_dropdown.label),
                fit_range=(fit_min, fit_max),
            )
            update_calibration_plot()
        else:
            fit_min_spinner.value = old_value

    fit_min_spinner = Spinner(title="Minimal fit value:", value=fit_min, step=0.1)
    fit_min_spinner.on_change("value", fit_min_spinner_callback)

    # Save calibration button
    def save_button_callback():
        palm.save_thz_calib(path=path_textinput.value)
        update_load_dropdown_menu()

    save_button = Button(label="Save", button_type="default", width=250)
    save_button.on_click(save_button_callback)

    # Load calibration button
    def load_dropdown_callback(event):
        new_value = event.item
        palm.load_thz_calib(os.path.join(path_textinput.value, new_value))
        update_calibration_plot()

    def update_load_dropdown_menu():
        new_menu = []
        calib_file_ext = ".palm_thz"
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(calib_file_ext):
                    new_menu.append((entry.name[: -len(calib_file_ext)], entry.name))
            load_dropdown.button_type = "default"
            load_dropdown.menu = sorted(new_menu, reverse=True)
        else:
            load_dropdown.button_type = "danger"
            load_dropdown.menu = new_menu

    doc.add_next_tick_callback(update_load_dropdown_menu)
    doc.add_periodic_callback(update_load_dropdown_menu, 5000)

    load_dropdown = Dropdown(label="Load", menu=[], width=250)
    load_dropdown.on_click(load_dropdown_callback)

    # Calibration constants
    calib_const_div = Div(
        text=f"""
        thz_slope = {0}
        """
    )

    # assemble
    tab_layout = column(
        row(
            scan_plot,
            Spacer(width=30),
            column(
                path_textinput,
                scans_dropdown,
                calibrate_button,
                fit_max_spinner,
                fit_min_spinner,
                row(save_button, load_dropdown),
                calib_const_div,
            ),
        )
    )

    return Panel(child=tab_layout, title="THz Calibration")
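
The dropdown menus in this tab stay fresh through doc.add_periodic_callback, which rescans the calibration folder every 5 seconds and rebuilds the (label, value) menu. A self-contained reduction of that pattern (the folder path is a placeholder):

import os
from bokeh.io import curdoc
from bokeh.models import Dropdown

scans_dropdown = Dropdown(label="ECO scans", menu=[])

def rescan(path="/tmp/scans"):  # placeholder folder
    menu = []
    if os.path.isdir(path):
        for entry in os.scandir(path):
            if entry.is_file() and entry.name.endswith(".json"):
                menu.append((entry.name, entry.name))
    scans_dropdown.menu = sorted(menu, reverse=True)

curdoc().add_periodic_callback(rescan, 5000)
curdoc().add_root(scans_dropdown)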
Example #7
u_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][0],
                    title="u(x,y):")
v_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][1],
                    title="v(x,y):")

# dropdown menu for selecting one of the sample functions
sample_fun_input = Dropdown(
    label="choose a sample function pair or enter one below",
    menu=odesystem_settings.sample_system_names)

# Interactor for entering starting point of initial condition
interactor = my_bokeh_utils.Interactor(plot)

# initialize callback behaviour
sample_fun_input.on_click(sample_fun_change)
u_input.on_change('value', ode_change)
v_input.on_change('value', ode_change)
interactor.on_click(initial_value_change)

# calculate data
init_data()

# lists all the controls in our app associated with the default_funs panel
function_controls = widgetbox(sample_fun_input, u_input, v_input, width=400)

# refresh quiver field and streamlines every 100 ms
curdoc().add_periodic_callback(refresh_user_view, 100)
# make layout
curdoc().add_root(row(function_controls, plot))
Example #8
def dropdownX_handler3(new):
    dictionary['xdata3'] = dictionary[new.item]
    dict['x3'] = new.item
    ds3.data = dictionary
    plot3.xaxis.axis_label = new.item
    update_tooltip()


def dropdownX_handler4(new):
    dictionary['xdata4'] = dictionary[new.item]
    dict['x4'] = new.item
    ds4.data = dictionary
    plot4.xaxis.axis_label = new.item
    update_tooltip()


dropdownX1.on_click(dropdownX_handler1)
dropdownX2.on_click(dropdownX_handler2)
dropdownX3.on_click(dropdownX_handler3)
dropdownX4.on_click(dropdownX_handler4)

dropdownY1 = Dropdown(label="Y value1", button_type="default", menu=menu)
dropdownY2 = Dropdown(label="Y value2", button_type="default", menu=menu)
dropdownY3 = Dropdown(label="Y value3", button_type="default", menu=menu)
dropdownY4 = Dropdown(label="Y value4", button_type="default", menu=menu)


def dropdownY_handler1(new):
    dictionary['ydata1'] = dictionary[new.item]
    dict['y1'] = new.item
    ds1.data = dictionary
    plot1.yaxis.axis_label = new.item
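
The four X handlers (and four Y handlers) differ only in which dictionary key they write and which plot they relabel, so a closure factory can generate them instead of copy-pasting. A sketch reusing the names from above (make_axis_handler is hypothetical):

def make_axis_handler(data_key, source, axis, data):
    def handler(event):
        data[data_key] = data[event.item]
        source.data = data
        axis.axis_label = event.item
    return handler

# e.g. dropdownX1.on_click(make_axis_handler('xdata1', ds1, plot1.xaxis, dictionary))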
Example #9
def create():
    det_data = []
    fit_params = {}
    js_data = ColumnDataSource(data=dict(content=["", ""], fname=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith((".ccl", ".dat")):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        scan_list = [s["idx"] for s in det_data]
        file_list = []
        for scan in det_data:
            file_list.append(os.path.basename(scan["original_filename"]))

        scan_table_source.data.update(
            file=file_list,
            scan=scan_list,
            param=[None] * len(scan_list),
            fit=[0] * len(scan_list),
            export=[True] * len(scan_list),
        )
        scan_table_source.selected.indices = []
        scan_table_source.selected.indices = [0]

        scan_motor_select.options = det_data[0]["scan_motors"]
        scan_motor_select.value = det_data[0]["scan_motor"]
        param_select.value = "user defined"

    file_select = MultiSelect(title="Available .ccl/.dat files:",
                              width=210,
                              height=250)

    def file_open_button_callback():
        nonlocal det_data
        det_data = []
        for f_name in file_select.value:
            with open(f_name) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        for f_name in file_select.value:
            with open(f_name) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            det_data.extend(append_data)

        _init_datatable()

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal det_data
        det_data = []
        for f_str, f_name in zip(new, upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        for f_str, f_name in zip(new, append_upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            det_data.extend(append_data)

        _init_datatable()

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    append_upload_button = FileInput(accept=".ccl,.dat",
                                     multiple=True,
                                     width=200,
                                     disabled=True)
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, _old, new):
        if det_data:
            pyzebra.normalize_dataset(det_data, new)
            _update_plot()

    monitor_spinner = Spinner(title="Monitor:",
                              mode="int",
                              value=100_000,
                              low=1,
                              width=145)
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def scan_motor_select_callback(_attr, _old, new):
        if det_data:
            for scan in det_data:
                scan["scan_motor"] = new
            _update_plot()

    scan_motor_select = Select(title="Scan motor:", options=[], width=145)
    scan_motor_select.on_change("value", scan_motor_select_callback)

    def _update_table():
        fit_ok = [(1 if "fit" in scan else 0) for scan in det_data]
        scan_table_source.data.update(fit=fit_ok)

    def _update_plot():
        _update_single_scan_plot(_get_selected_scan())
        _update_overview()

    def _update_single_scan_plot(scan):
        scan_motor = scan["scan_motor"]

        y = scan["counts"]
        x = scan[scan_motor]

        plot.axis[0].axis_label = scan_motor
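        # counts are Poisson-distributed, so y +/- sqrt(y) spans the error band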
        plot_scatter_source.data.update(x=x,
                                        y=y,
                                        y_upper=y + np.sqrt(y),
                                        y_lower=y - np.sqrt(y))

        fit = scan.get("fit")
        if fit is not None:
            x_fit = np.linspace(x[0], x[-1], 100)
            plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

            x_bkg = []
            y_bkg = []
            xs_peak = []
            ys_peak = []
            comps = fit.eval_components(x=x_fit)
            for i, model in enumerate(fit_params):
                if "linear" in model:
                    x_bkg = x_fit
                    y_bkg = comps[f"f{i}_"]

                elif any(val in model
                         for val in ("gaussian", "voigt", "pvoigt")):
                    xs_peak.append(x_fit)
                    ys_peak.append(comps[f"f{i}_"])

            plot_bkg_source.data.update(x=x_bkg, y=y_bkg)
            plot_peak_source.data.update(xs=xs_peak, ys=ys_peak)

            fit_output_textinput.value = fit.fit_report()

        else:
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""

    def _update_overview():
        xs = []
        ys = []
        param = []
        x = []
        y = []
        par = []
        for s, p in enumerate(scan_table_source.data["param"]):
            if p is not None:
                scan = det_data[s]
                scan_motor = scan["scan_motor"]
                xs.append(scan[scan_motor])
                x.extend(scan[scan_motor])
                ys.append(scan["counts"])
                y.extend([float(p)] * len(scan[scan_motor]))
                param.append(float(p))
                par.extend(scan["counts"])

        if det_data:
            scan_motor = det_data[0]["scan_motor"]
            ov_plot.axis[0].axis_label = scan_motor
            ov_param_plot.axis[0].axis_label = scan_motor

        ov_plot_mline_source.data.update(xs=xs,
                                         ys=ys,
                                         param=param,
                                         color=color_palette(len(xs)))

        if y:
            mapper["transform"].low = np.min([np.min(y) for y in ys])
            mapper["transform"].high = np.max([np.max(y) for y in ys])
        ov_param_plot_scatter_source.data.update(x=x, y=y, param=par)

        if y:
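            # interpolate the scattered (x, param, counts) points onto a regular
            # grid, roughly one sample per 10 screen pixels, for the Image glyph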
            interp_f = interpolate.interp2d(x, y, par)
            x1, x2 = min(x), max(x)
            y1, y2 = min(y), max(y)
            image = interp_f(
                np.linspace(x1, x2, ov_param_plot.inner_width // 10),
                np.linspace(y1, y2, ov_param_plot.inner_height // 10),
                assume_sorted=True,
            )
            ov_param_plot_image_source.data.update(image=[image],
                                                   x=[x1],
                                                   y=[y1],
                                                   dw=[x2 - x1],
                                                   dh=[y2 - y1])
        else:
            ov_param_plot_image_source.data.update(image=[],
                                                   x=[],
                                                   y=[],
                                                   dw=[],
                                                   dh=[])

    def _update_param_plot():
        x = []
        y = []
        fit_param = fit_param_select.value
        for s, p in zip(det_data, scan_table_source.data["param"]):
            if "fit" in s and fit_param:
                x.append(p)
                y.append(s["fit"].values[fit_param])

        param_plot_scatter_source.data.update(x=x, y=y)

    # Main plot
    plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(only_visible=True),
        plot_height=450,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue"))
    plot.add_layout(
        Whisker(source=plot_scatter_source,
                base="x",
                upper="y_upper",
                lower="y_lower"))

    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"))

    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"))

    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)

    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_to_span)

    plot.add_layout(
        Legend(
            items=[
                ("data", [plot_scatter]),
                ("best fit", [plot_fit]),
                ("peak", [plot_peak]),
                ("linear", [plot_bkg]),
            ],
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Overview multilines plot
    ov_plot = Plot(x_range=DataRange1d(),
                   y_range=DataRange1d(),
                   plot_height=450,
                   plot_width=700)

    ov_plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    ov_plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    ov_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    ov_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    ov_plot_mline_source = ColumnDataSource(
        dict(xs=[], ys=[], param=[], color=[]))
    ov_plot.add_glyph(ov_plot_mline_source,
                      MultiLine(xs="xs", ys="ys", line_color="color"))

    hover_tool = HoverTool(tooltips=[("param", "@param")])
    ov_plot.add_tools(PanTool(), WheelZoomTool(), hover_tool, ResetTool())
    ov_plot.toolbar.logo = None

    # Overview params plot
    ov_param_plot = Plot(x_range=DataRange1d(),
                         y_range=DataRange1d(),
                         plot_height=450,
                         plot_width=700)

    ov_param_plot.add_layout(LinearAxis(axis_label="Param"), place="left")
    ov_param_plot.add_layout(LinearAxis(axis_label="Scan motor"),
                             place="below")

    ov_param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    ov_param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    ov_param_plot_image_source = ColumnDataSource(
        dict(image=[], x=[], y=[], dw=[], dh=[]))
    ov_param_plot.add_glyph(
        ov_param_plot_image_source,
        Image(image="image", x="x", y="y", dw="dw", dh="dh"))

    ov_param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[], param=[]))
    mapper = linear_cmap(field_name="param", palette=Turbo256, low=0, high=50)
    ov_param_plot.add_glyph(
        ov_param_plot_scatter_source,
        Scatter(x="x", y="y", line_color=mapper, fill_color=mapper, size=10),
    )

    ov_param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    ov_param_plot.toolbar.logo = None

    # Parameter plot
    param_plot = Plot(x_range=DataRange1d(),
                      y_range=DataRange1d(),
                      plot_height=400,
                      plot_width=700)

    param_plot.add_layout(LinearAxis(axis_label="Fit parameter"), place="left")
    param_plot.add_layout(LinearAxis(axis_label="Parameter"), place="below")

    param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[]))
    param_plot.add_glyph(param_plot_scatter_source, Scatter(x="x", y="y"))

    param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    param_plot.toolbar.logo = None

    def fit_param_select_callback(_attr, _old, _new):
        _update_param_plot()

    fit_param_select = Select(title="Fit parameter", options=[], width=145)
    fit_param_select.on_change("value", fit_param_select_callback)

    # Plot tabs
    plots = Tabs(tabs=[
        Panel(child=plot, title="single scan"),
        Panel(child=ov_plot, title="overview"),
        Panel(child=ov_param_plot, title="overview map"),
        Panel(child=column(param_plot, row(fit_param_select)),
              title="parameter plot"),
    ])

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            scan_table_source.selected.indices = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        _update_plot()

    def scan_table_source_callback(_attr, _old, _new):
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(file=[], scan=[], param=[], fit=[], export=[]))
    scan_table_source.on_change("data", scan_table_source_callback)

    scan_table = DataTable(
        source=scan_table_source,
        columns=[
            TableColumn(field="file", title="file", width=150),
            TableColumn(field="scan", title="scan", width=50),
            TableColumn(field="param",
                        title="param",
                        editor=NumberEditor(),
                        width=50),
            TableColumn(field="fit", title="Fit", width=50),
            TableColumn(field="export",
                        title="Export",
                        editor=CheckboxEditor(),
                        width=50),
        ],
        width=410,  # +60 because of the index column
        editable=True,
        autosize_mode="none",
    )

    def scan_table_source_update_callback(_attr, _old, _new):
        if scan_table_source.selected.indices:
            _update_plot()

    scan_table_source.selected.on_change("indices", scan_table_select_callback)
    scan_table_source.on_change("data", scan_table_source_update_callback)

    def _get_selected_scan():
        return det_data[scan_table_source.selected.indices[0]]

    def param_select_callback(_attr, _old, new):
        if new == "user defined":
            param = [None] * len(det_data)
        else:
            param = [scan[new] for scan in det_data]

        scan_table_source.data["param"] = param
        _update_param_plot()

    param_select = Select(
        title="Parameter:",
        options=["user defined", "temp", "mf", "h", "k", "l"],
        value="user defined",
        width=145,
    )
    param_select.on_change("value", param_select_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        fit_from_span.location = new

    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    def fit_to_spinner_callback(_attr, _old, new):
        fit_to_span.location = new

    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        # bokeh requires (str, str) for MultiSelect options
        new_tag = f"{click.item}-{fitparams_select.tags[0]}"
        fitparams_select.options.append((new_tag, click.item))
        fit_params[new_tag] = fitparams_factory(click.item)
        fitparams_select.tags[0] += 1

    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=[
            ("Linear", "linear"),
            ("Gaussian", "gaussian"),
            ("Voigt", "voigt"),
            ("Pseudo Voigt", "pvoigt"),
            # ("Pseudo Voigt1", "pseudovoigt1"),
        ],
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            fitparams_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        if new:
            fitparams_table_source.data.update(fit_params[new[0]])
        else:
            fitparams_table_source.data.update(
                dict(param=[], value=[], vary=[], min=[], max=[]))

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        if fitparams_select.value:
            sel_tag = fitparams_select.value[0]
            del fit_params[sel_tag]
            for elem in fitparams_select.options:
                if elem[0] == sel_tag:
                    fitparams_select.options.remove(elem)
                    break

            fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        if function == "linear":
            params = ["slope", "intercept"]
        elif function == "gaussian":
            params = ["amplitude", "center", "sigma"]
        elif function == "voigt":
            params = ["amplitude", "center", "sigma", "gamma"]
        elif function == "pvoigt":
            params = ["amplitude", "center", "sigma", "fraction"]
        elif function == "pseudovoigt1":
            params = ["amplitude", "center", "g_sigma", "l_sigma", "fraction"]
        else:
            raise ValueError("Unknown fit function")

        n = len(params)
        fitparams = dict(
            param=params,
            value=[None] * n,
            vary=[True] * n,
            min=[None] * n,
            max=[None] * n,
        )

        if function == "linear":
            fitparams["value"] = [0, 1]
            fitparams["vary"] = [False, True]
            fitparams["min"] = [None, 0]

        elif function == "gaussian":
            fitparams["min"] = [0, None, None]

        return fitparams

    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=[
            TableColumn(field="param", title="Parameter"),
            TableColumn(field="value", title="Value", editor=NumberEditor()),
            TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
            TableColumn(field="min", title="Min", editor=NumberEditor()),
            TableColumn(field="max", title="Max", editor=NumberEditor()),
        ],
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `linear` and `gaussian` fit functions added
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="linear"))
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="gaussian"))
    fitparams_select.value = ["gaussian-1"]  # select the gaussian entry

    fit_output_textinput = TextAreaInput(title="Fit results:",
                                         width=750,
                                         height=200)

    def proc_all_button_callback():
        for scan, export in zip(det_data, scan_table_source.data["export"]):
            if export:
                pyzebra.fit_scan(scan,
                                 fit_params,
                                 fit_from=fit_from_spinner.value,
                                 fit_to=fit_to_spinner.value)
                pyzebra.get_area(
                    scan,
                    area_method=AREA_METHODS[area_method_radiobutton.active],
                    lorentz=lorentz_checkbox.active,
                )

        _update_plot()
        _update_table()

        for scan in det_data:
            if "fit" in scan:
                options = list(scan["fit"].params.keys())
                fit_param_select.options = options
                fit_param_select.value = options[0]
                break
        _update_param_plot()

    proc_all_button = Button(label="Process All",
                             button_type="primary",
                             width=145)
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        scan = _get_selected_scan()
        pyzebra.fit_scan(scan,
                         fit_params,
                         fit_from=fit_from_spinner.value,
                         fit_to=fit_to_spinner.value)
        pyzebra.get_area(
            scan,
            area_method=AREA_METHODS[area_method_radiobutton.active],
            lorentz=lorentz_checkbox.active,
        )

        _update_plot()
        _update_table()

        for scan in det_data:
            if "fit" in scan:
                options = list(scan["fit"].params.keys())
                fit_param_select.options = options
                fit_param_select.value = options[0]
                break
        _update_param_plot()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(labels=["Function", "Area"],
                                         active=0,
                                         width=145)

    lorentz_checkbox = CheckboxGroup(labels=["Lorentz Correction"],
                                     width=145,
                                     margin=(13, 5, 5, 5))

    export_preview_textinput = TextAreaInput(title="Export file preview:",
                                             width=450,
                                             height=400)

    def _update_preview():
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp"
            export_data = []
            for s, export in zip(det_data, scan_table_source.data["export"]):
                if export:
                    export_data.append(s)

            pyzebra.export_1D(export_data, temp_file, "fullprof")

            exported_content = ""
            file_content = []
            for ext in (".comm", ".incomm"):
                fname = temp_file + ext
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                        exported_content += f"{ext} file:\n" + content
                else:
                    content = ""
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = exported_content

    save_button = Button(label="Download File(s)",
                         button_type="success",
                         width=220)
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))

    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(scan_table,
                         row(monitor_spinner, scan_motor_select, param_select))

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(export_preview_textinput, row(save_button))

    tab_layout = column(
        row(import_layout, scan_layout, plots, Spacer(width=30),
            export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="param study")
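
A minimal sketch of how a tab factory like the one above is typically mounted in a Bokeh server application (assuming a zero-argument `create`; the file name is hypothetical):

from bokeh.io import curdoc
from bokeh.models import Tabs

# run with: bokeh serve app.py
curdoc().add_root(Tabs(tabs=[create()]))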
Example #10
0
    def bkapp(doc):
        def fitButtonCallback(new):
            x_range = [p.x_range.start, p.x_range.end]
            y_range = [p.y_range.start, p.y_range.end]
            x_inds = np.where(np.logical_and(bin_mids >= x_range[0], bin_mids <= x_range[1]), bin_mids, np.nan)
            x = x_inds[np.isfinite(x_inds)]
            y = hist[np.isfinite(x_inds)]
            x_fit = np.linspace(x_range[0], x_range[1], 10000)
            p0 = [y.max(), x[np.argmax(y)], x[np.argmax(y)], 0]
            p_fit = curve_fit(_gaus, x, y, p0, ftol=1e-2)[0]
            y_fit = _gaus(x_fit, *p_fit)
            # p.line(x_fit, y_fit, legend_label='Amp. = %5.2f, Mean = %5.2f, Std. = %5.2f, Base = %5.2f' % p_fit)
            fitDataSource.patch(dict(x=[(slice(0, len(x_fit)), x_fit)], y=[(slice(0, len(y_fit)), y_fit)]))
            fit_str = 'Amp. = %5.2f, Mean = %5.2f, Std. = %5.2f, Base = %5.2f' % (p_fit[0], p_fit[1], p_fit[2], p_fit[3])
            #p.line(x=x_fit, y=y_fit, legend_label=new_data['legend_label'], color='black')
            p.x_range = Range1d(bins[0], bins[-1])
            p.y_range = Range1d(0, hist.max())
            p.legend.items = [(settings['sz_key'], [hist_line]), (fit_str, [fit_line])]

        def SzDropDownCallback(new):
            settings['sz_key'] = new.item
            bins = np.linspace(0, 130000, 1000)
            size_datasets = pysp2.util.calc_diams_masses(my_cal_data[settings['sz_key']])
            hist, bins = np.histogram(size_datasets[settings['variable']].where(size_datasets.ScatRejectKey == 0).values,
                                      bins=bins)
            histDataSource.patch(dict(x=[(slice(0, len(bin_mids)), bin_mids)], y=[(slice(0, len(hist)), hist)]))
            p.legend.items = [(settings['sz_key'], [hist_line]), ("", [fit_line])]

        def VariableCallback(new):
            settings['variable'] = new.item
            bins = np.linspace(0, 130000, 1000)
            size_datasets = pysp2.util.calc_diams_masses(my_cal_data[settings['sz_key']])
            hist, bins = np.histogram(size_datasets[settings['variable']].where(size_datasets.ScatRejectKey == 0).values,
                                      bins=bins)
            histDataSource.patch(dict(x=[(slice(0, len(bin_mids)), bin_mids)], y=[(slice(0, len(hist)), hist)]))
            p.legend.items = [(settings['sz_key'], [hist_line]), ("", [fit_line])]

        settings = {'sz_key':'scat_200', 'variable': 'PkHt_ch0'}

        sizes = np.array([350, 300, 269, 220, 200, 170])
        scat_mean = np.ones_like(sizes, dtype=float)

        # bins = np.logspace(-13, -10, 120)
        bins = np.linspace(0, 130000, 1000)
        hists = np.zeros((len(sizes), len(bins) - 1))
        size_datasets = pysp2.util.calc_diams_masses(my_cal_data[settings['sz_key']])
        hist, bins = np.histogram(size_datasets[settings['variable']].where(size_datasets.ScatRejectKey == 0).values,
                                  bins=bins)
        p = figure(tools="pan,box_zoom,reset,save", x_range=(bins[0], bins[-1]), y_range=(0, hist.max()))
        bin_mids = (bins[:-1] + bins[1:])/2.0
        histDataSource = ColumnDataSource(data=dict(x=bin_mids, y=hist))
        x_fit = np.linspace(0, 130000, 10000)
        fitDataSource = ColumnDataSource(data=dict(x=x_fit, y=np.nan*np.ones_like(x_fit)))
        hist_line = p.line('x', 'y', source=histDataSource, legend_label=settings['sz_key'])
        fit_line = p.line('x', 'y', source=fitDataSource, legend_label="")
        button = Button(label="Curve fit")
        button.on_click(fitButtonCallback)
        size_menu = []
        for keys in my_cal_data.keys():
            size_menu.append((keys, keys))
        SzDropDown = Dropdown(label="Calibration data", menu=size_menu)
        SzDropDown.on_click(SzDropDownCallback)
        vars = ["PkHt_ch0", "PkHt_ch1", "PkHt_ch4", "PkHt_ch5", "FtAmp_ch0", 'FtAmp_ch4']
        vars = [(x, x) for x in vars]
        VarDropDown = Dropdown(label="Variable", menu=vars)
        VarDropDown.on_click(VariableCallback)
        doc.add_root(column(button, SzDropDown, VarDropDown, p))
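
`_gaus` is called by `fitButtonCallback` but not defined in this snippet; judging from the four-element initial guess `p0` (amplitude, mean, std, baseline), a plausible definition would be:

def _gaus(x, amp, mean, std, base):
    # Gaussian peak on a constant baseline
    return amp * np.exp(-(x - mean) ** 2 / (2 * std ** 2)) + base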
Example #11
0
class BokehCorpusAnnotator(BokehCorpusExplorer):
    """
    Annotate text data points via callbacks.

    Features:

    - alter values in the 'label' column through the widgets.
    - **SERVER ONLY**: only works in a setting that allows Python callbacks.
    """

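    # per-subset glyph styling: constant kwargs plus search-responsive kwargs,
    # presumably in (attribute, match, non-match, default) form, consumed by
    # the parent explorer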
    DATA_KEY_TO_KWARGS = {
        _key: {
            "constant": {
                "line_alpha": 0.3
            },
            "search": {
                "size": ("size", 10, 5, 7),
                "fill_alpha": ("fill_alpha", 0.3, 0.1, 0.2),
            },
        }
        for _key in ["raw", "train", "dev", "test"]
    }

    def __init__(self, df_dict, **kwargs):
        """Conceptually the same as the parent method."""
        super().__init__(df_dict, **kwargs)

    def _layout_widgets(self):
        """Define the layout of widgets."""
        layout_rows = (
            row(self.search_pos, self.search_neg),
            row(self.data_key_button_group),
            row(self.annotator_input, self.annotator_apply,
                self.annotator_export),
        )
        return column(*layout_rows)

    def _setup_widgets(self):
        """
        Create annotator widgets and assign Python callbacks.
        """
        from bokeh.models import TextInput, Button, Dropdown

        super()._setup_widgets()

        self.annotator_input = TextInput(title="Label:")
        self.annotator_apply = Button(
            label="Apply",
            button_type="primary",
            height_policy="fit",
            width_policy="min",
        )
        self.annotator_export = Dropdown(
            label="Export",
            button_type="warning",
            menu=["Excel", "CSV", "JSON", "pickle"],
            height_policy="fit",
            width_policy="min",
        )

        def callback_apply():
            """
            A callback on clicking the 'self.annotator_apply' button.

            Update labels in the source.
            """
            label = self.annotator_input.value
            selected_idx = self.sources["raw"].selected.indices
            if not selected_idx:
                self._warn(
                    "Attempting annotation: did not select any data points. Eligible subset is 'raw'."
                )
                return
            example_old = self.dfs["raw"].at[selected_idx[0], "label"]
            # use .loc: .at only accepts a scalar index, and selected_idx is a list
            self.dfs["raw"].loc[selected_idx, "label"] = label
            example_new = self.dfs["raw"].at[selected_idx[0], "label"]
            self._good(
                f"Applied {len(selected_idx)} annotations: {label} (e.g. {example_old} -> {example_new})"
            )

            self._update_sources()
            self.plot()
            self._good(f"Updated annotator plot at {current_time()}")

        def callback_export(event, path_root=None):
            """
            A callback on clicking the 'self.annotator_export' button.

            Saves the dataframe to a pickle.
            """
            export_format = event.item

            # auto-determine the export path root
            if path_root is None:
                timestamp = current_time("%Y%m%d%H%M%S")
                path_root = f"hover-annotated-df-{timestamp}"

            export_df = pd.concat(self.dfs,
                                  axis=0,
                                  sort=False,
                                  ignore_index=True)

            if export_format == "Excel":
                export_path = f"{path_root}.xlsx"
                export_df.to_excel(export_path, index=False)
            elif export_format == "CSV":
                export_path = f"{path_root}.csv"
                export_df.to_csv(export_path, index=False)
            elif export_format == "JSON":
                export_path = f"{path_root}.json"
                export_df.to_json(export_path, orient="records")
            elif export_format == "pickle":
                export_path = f"{path_root}.pkl"
                export_df.to_pickle(export_path)
            else:
                raise ValueError(f"Unexpected export format {export_format}")

            self._good(f"Saved DataFrame to {export_path}")

        # keep the references to the callbacks
        self._callback_apply = callback_apply
        self._callback_export = callback_export

        # assign callbacks
        self.annotator_apply.on_click(self._callback_apply)
        self.annotator_export.on_click(self._callback_export)

    def plot(self):
        """
        Re-plot with the new labels.

        Overrides the parent method.
        Determines the label->color mapping dynamically.
        """
        labels, cmap = self.auto_labels_cmap()

        for _key, _source in self.sources.items():
            self.figure.circle(
                "x",
                "y",
                name=_key,
                color=factor_cmap("label", cmap, labels),
                legend_group="label",
                source=_source,
                **self.glyph_kwargs[_key],
            )
            self._good(
                f"Plotted subset {_key} with {self.dfs[_key].shape[0]} points")

        self.auto_legend_correction()
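
A hedged usage sketch (hover's real entry points may differ): since the callbacks are server-only, the annotator would be mounted in a `bokeh serve` script, e.g.

from bokeh.io import curdoc
from bokeh.layouts import column

annotator = BokehCorpusAnnotator(df_dict)  # df_dict: {"raw": df, "train": df, ...}
annotator.plot()
# `figure` is assumed to be created by the parent explorer class
curdoc().add_root(column(annotator._layout_widgets(), annotator.figure))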
Example #12
0
source_function1 = ColumnDataSource(data=dict(x=[], y=[]))
source_function2 = ColumnDataSource(data=dict(x=[], y=[]))
source_result = ColumnDataSource(data=dict(x=[], y=[]))
source_convolution = ColumnDataSource(data=dict(x=[], y=[]))
source_xmarker = ColumnDataSource(data=dict(x=[], y=[]))
source_overlay = ColumnDataSource(data=dict(x=[], y=[], y_neg=[], y_pos=[]))

# initialize properties
update_is_enabled = True

# initialize controls
# dropdown menu for sample functions
function_type = Dropdown(
    label="choose a sample function pair or enter one below",
    menu=convolution_settings.sample_function_names)
function_type.on_click(function_pair_input_change)

# slider controlling the evaluated x value of the convolved function
x_value_input = Slider(title="x value",
                       name='x value',
                       value=convolution_settings.x_value_init,
                       start=convolution_settings.x_value_min,
                       end=convolution_settings.x_value_max,
                       step=convolution_settings.x_value_step)
x_value_input.on_change('value', input_change)
# text input for the first function to be convolved
function1_input = TextInput(value=convolution_settings.function1_input_init,
                            title="my first function:")
function1_input.on_change('value', input_change)
# text input for the second function to be convolved
function2_input = TextInput(value=convolution_settings.function1_input_init,
                            title="my second function:")
function2_input.on_change('value', input_change)
Example #13
0
def create():
    det_data = {}
    fit_params = {}
    js_data = ColumnDataSource(
        data=dict(content=["", ""], fname=["", ""], ext=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Cannot find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith((".ccl", ".dat")):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        scan_list = [s["idx"] for s in det_data]
        hkl = [f'{s["h"]} {s["k"]} {s["l"]}' for s in det_data]
        export = [s.get("active", True) for s in det_data]
        scan_table_source.data.update(
            scan=scan_list,
            hkl=hkl,
            fit=[0] * len(scan_list),
            export=export,
        )
        scan_table_source.selected.indices = []
        scan_table_source.selected.indices = [0]

        merge_options = [(str(i), f"{i} ({idx})")
                         for i, idx in enumerate(scan_list)]
        merge_from_select.options = merge_options
        merge_from_select.value = merge_options[0][0]

    file_select = MultiSelect(title="Available .ccl/.dat files:",
                              width=210,
                              height=250)

    def file_open_button_callback():
        nonlocal det_data
        det_data = []
        for f_name in file_select.value:
            with open(f_name) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    pyzebra.merge_datasets(det_data, append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    pyzebra.merge_duplicates(det_data)
                    js_data.data.update(fname=[base, base])

        _init_datatable()
        append_upload_button.disabled = False

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        for f_name in file_select.value:
            with open(f_name) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            pyzebra.merge_datasets(det_data, append_data)

        _init_datatable()

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal det_data
        det_data = []
        for f_str, f_name in zip(new, upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    pyzebra.merge_datasets(det_data, append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    pyzebra.merge_duplicates(det_data)
                    js_data.data.update(fname=[base, base])

        _init_datatable()
        append_upload_button.disabled = False

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        for f_str, f_name in zip(new, append_upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            pyzebra.merge_datasets(det_data, append_data)

        _init_datatable()

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    append_upload_button = FileInput(accept=".ccl,.dat",
                                     multiple=True,
                                     width=200,
                                     disabled=True)
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, old, new):
        if det_data:
            pyzebra.normalize_dataset(det_data, new)
            _update_plot(_get_selected_scan())

    monitor_spinner = Spinner(title="Monitor:",
                              mode="int",
                              value=100_000,
                              low=1,
                              width=145)
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def _update_table():
        fit_ok = [(1 if "fit" in scan else 0) for scan in det_data]
        scan_table_source.data.update(fit=fit_ok)

    def _update_plot(scan):
        scan_motor = scan["scan_motor"]

        y = scan["counts"]
        x = scan[scan_motor]

        plot.axis[0].axis_label = scan_motor
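        # error bars assume Poisson counting statistics: sigma = sqrt(counts)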
        plot_scatter_source.data.update(x=x,
                                        y=y,
                                        y_upper=y + np.sqrt(y),
                                        y_lower=y - np.sqrt(y))

        fit = scan.get("fit")
        if fit is not None:
            x_fit = np.linspace(x[0], x[-1], 100)
            plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

            x_bkg = []
            y_bkg = []
            xs_peak = []
            ys_peak = []
            comps = fit.eval_components(x=x_fit)
            for i, model in enumerate(fit_params):
                if "linear" in model:
                    x_bkg = x_fit
                    y_bkg = comps[f"f{i}_"]

                elif any(val in model
                         for val in ("gaussian", "voigt", "pvoigt")):
                    xs_peak.append(x_fit)
                    ys_peak.append(comps[f"f{i}_"])

            plot_bkg_source.data.update(x=x_bkg, y=y_bkg)
            plot_peak_source.data.update(xs=xs_peak, ys=ys_peak)

            fit_output_textinput.value = fit.fit_report()

        else:
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""

    # Main plot
    plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(only_visible=True),
        plot_height=470,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue"))
    plot.add_layout(
        Whisker(source=plot_scatter_source,
                base="x",
                upper="y_upper",
                lower="y_lower"))

    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"))

    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"))

    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)

    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_to_span)
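    # the two dashed spans visualize the fit range bounds driven by the
    # "Fit from" / "to" spinners below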

    plot.add_layout(
        Legend(
            items=[
                ("data", [plot_scatter]),
                ("best fit", [plot_fit]),
                ("peak", [plot_peak]),
                ("linear", [plot_bkg]),
            ],
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            scan_table_source.selected.indices = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        _update_plot(det_data[new[0]])

    def scan_table_source_callback(_attr, _old, _new):
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(scan=[], hkl=[], fit=[], export=[]))
    scan_table_source.on_change("data", scan_table_source_callback)

    scan_table = DataTable(
        source=scan_table_source,
        columns=[
            TableColumn(field="scan", title="Scan", width=50),
            TableColumn(field="hkl", title="hkl", width=100),
            TableColumn(field="fit", title="Fit", width=50),
            TableColumn(field="export",
                        title="Export",
                        editor=CheckboxEditor(),
                        width=50),
        ],
        width=310,  # +60 because of the index column
        height=350,
        autosize_mode="none",
        editable=True,
    )

    scan_table_source.selected.on_change("indices", scan_table_select_callback)

    def _get_selected_scan():
        return det_data[scan_table_source.selected.indices[0]]

    merge_from_select = Select(title="scan:", width=145)

    def merge_button_callback():
        scan_into = _get_selected_scan()
        scan_from = det_data[int(merge_from_select.value)]

        if scan_into is scan_from:
            print("WARNING: Selected scans for merging are identical")
            return

        pyzebra.merge_scans(scan_into, scan_from)
        _update_plot(_get_selected_scan())

    merge_button = Button(label="Merge into current", width=145)
    merge_button.on_click(merge_button_callback)

    def restore_button_callback():
        pyzebra.restore_scan(_get_selected_scan())
        _update_plot(_get_selected_scan())

    restore_button = Button(label="Restore scan", width=145)
    restore_button.on_click(restore_button_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        fit_from_span.location = new

    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    def fit_to_spinner_callback(_attr, _old, new):
        fit_to_span.location = new

    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        # bokeh requires (str, str) for MultiSelect options
        new_tag = f"{click.item}-{fitparams_select.tags[0]}"
        fitparams_select.options.append((new_tag, click.item))
        fit_params[new_tag] = fitparams_factory(click.item)
        fitparams_select.tags[0] += 1

    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=[
            ("Linear", "linear"),
            ("Gaussian", "gaussian"),
            ("Voigt", "voigt"),
            ("Pseudo Voigt", "pvoigt"),
            # ("Pseudo Voigt1", "pseudovoigt1"),
        ],
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            fitparams_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        if new:
            fitparams_table_source.data.update(fit_params[new[0]])
        else:
            fitparams_table_source.data.update(
                dict(param=[], value=[], vary=[], min=[], max=[]))

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        if fitparams_select.value:
            sel_tag = fitparams_select.value[0]
            del fit_params[sel_tag]
            for elem in fitparams_select.options:
                if elem[0] == sel_tag:
                    fitparams_select.options.remove(elem)
                    break

            fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        if function == "linear":
            params = ["slope", "intercept"]
        elif function == "gaussian":
            params = ["amplitude", "center", "sigma"]
        elif function == "voigt":
            params = ["amplitude", "center", "sigma", "gamma"]
        elif function == "pvoigt":
            params = ["amplitude", "center", "sigma", "fraction"]
        elif function == "pseudovoigt1":
            params = ["amplitude", "center", "g_sigma", "l_sigma", "fraction"]
        else:
            raise ValueError("Unknown fit function")

        n = len(params)
        fitparams = dict(
            param=params,
            value=[None] * n,
            vary=[True] * n,
            min=[None] * n,
            max=[None] * n,
        )

        if function == "linear":
            fitparams["value"] = [0, 1]
            fitparams["vary"] = [False, True]
            fitparams["min"] = [None, 0]

        elif function == "gaussian":
            fitparams["min"] = [0, None, None]

        return fitparams

    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=[
            TableColumn(field="param", title="Parameter"),
            TableColumn(field="value", title="Value", editor=NumberEditor()),
            TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
            TableColumn(field="min", title="Min", editor=NumberEditor()),
            TableColumn(field="max", title="Max", editor=NumberEditor()),
        ],
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `linear` and `gaussian` fit functions added
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="linear"))
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="gaussian"))
    fitparams_select.value = ["gaussian-1"]  # select the gaussian entry

    fit_output_textinput = TextAreaInput(title="Fit results:",
                                         width=750,
                                         height=200)

    def proc_all_button_callback():
        for scan, export in zip(det_data, scan_table_source.data["export"]):
            if export:
                pyzebra.fit_scan(scan,
                                 fit_params,
                                 fit_from=fit_from_spinner.value,
                                 fit_to=fit_to_spinner.value)
                pyzebra.get_area(
                    scan,
                    area_method=AREA_METHODS[area_method_radiobutton.active],
                    lorentz=lorentz_checkbox.active,
                )

        _update_plot(_get_selected_scan())
        _update_table()

    proc_all_button = Button(label="Process All",
                             button_type="primary",
                             width=145)
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        scan = _get_selected_scan()
        pyzebra.fit_scan(scan,
                         fit_params,
                         fit_from=fit_from_spinner.value,
                         fit_to=fit_to_spinner.value)
        pyzebra.get_area(
            scan,
            area_method=AREA_METHODS[area_method_radiobutton.active],
            lorentz=lorentz_checkbox.active,
        )

        _update_plot(scan)
        _update_table()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(labels=["Function", "Area"],
                                         active=0,
                                         width=145)

    lorentz_checkbox = CheckboxGroup(labels=["Lorentz Correction"],
                                     width=145,
                                     margin=(13, 5, 5, 5))

    export_preview_textinput = TextAreaInput(title="Export file preview:",
                                             width=500,
                                             height=400)

    def _update_preview():
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp"
            export_data = []
            for s, export in zip(det_data, scan_table_source.data["export"]):
                if export:
                    export_data.append(s)

            pyzebra.export_1D(
                export_data,
                temp_file,
                export_target_select.value,
                hkl_precision=int(hkl_precision_select.value),
            )

            exported_content = ""
            file_content = []
            for ext in EXPORT_TARGETS[export_target_select.value]:
                fname = temp_file + ext
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                        exported_content += f"{ext} file:\n" + content
                else:
                    content = ""
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = exported_content

    def export_target_select_callback(_attr, _old, new):
        js_data.data.update(ext=EXPORT_TARGETS[new])
        _update_preview()

    export_target_select = Select(title="Export target:",
                                  options=list(EXPORT_TARGETS.keys()),
                                  value="fullprof",
                                  width=80)
    export_target_select.on_change("value", export_target_select_callback)
    js_data.data.update(ext=EXPORT_TARGETS[export_target_select.value])

    def hkl_precision_select_callback(_attr, _old, _new):
        _update_preview()

    hkl_precision_select = Select(title="hkl precision:",
                                  options=["2", "3", "4"],
                                  value="2",
                                  width=80)
    hkl_precision_select.on_change("value", hkl_precision_select_callback)

    save_button = Button(label="Download File(s)",
                         button_type="success",
                         width=200)
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))
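    # `javaScript` (module-level CustomJS code, not shown here) presumably reads
    # the content/fname/ext columns of js_data and triggers client-side downloads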

    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(
        scan_table,
        row(monitor_spinner, column(Spacer(height=19), restore_button)),
        row(column(Spacer(height=19), merge_button), merge_from_select),
    )

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(
        export_preview_textinput,
        row(export_target_select, hkl_precision_select,
            column(Spacer(height=19), row(save_button))),
    )

    tab_layout = column(
        row(import_layout, scan_layout, plot, Spacer(width=30), export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="ccl integrate")
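
The preview code above indexes a module-level EXPORT_TARGETS mapping by target name and iterates over its file extensions; a plausible shape, inferred from that usage rather than taken from pyzebra itself:

EXPORT_TARGETS = {
    "fullprof": (".comm", ".incomm"),  # assumed extension tuples
    "jana": (".comm", ".incomm"),
}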
Example #14
0
def create(palm):
    doc = curdoc()

    # Calibration averaged waveforms per photon energy
    waveform_plot = Plot(
        title=Title(text="eTOF calibration waveforms"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=760,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    waveform_plot.toolbar.logo = None
    waveform_plot_hovertool = HoverTool(
        tooltips=[("energy, eV", "@en"), ("eTOF bin", "$x{0.}")])

    waveform_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                            ResetTool(), waveform_plot_hovertool)

    # ---- axes
    waveform_plot.add_layout(LinearAxis(axis_label="eTOF time bin"),
                             place="below")
    waveform_plot.add_layout(LinearAxis(axis_label="Intensity",
                                        major_label_orientation="vertical"),
                             place="left")

    # ---- grid lines
    waveform_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    waveform_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- multiline glyphs
    waveform_ref_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_ref_multiline = waveform_plot.add_glyph(
        waveform_ref_source, MultiLine(xs="xs", ys="ys", line_color="blue"))

    waveform_str_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_str_multiline = waveform_plot.add_glyph(
        waveform_str_source, MultiLine(xs="xs", ys="ys", line_color="red"))

    # ---- legend
    waveform_plot.add_layout(
        Legend(items=[(
            "reference",
            [waveform_ref_multiline]), ("streaked",
                                        [waveform_str_multiline])]))
    waveform_plot.legend.click_policy = "hide"

    # ---- vertical spans
    photon_peak_ref_span = Span(location=0,
                                dimension="height",
                                line_dash="dashed",
                                line_color="blue")
    photon_peak_str_span = Span(location=0,
                                dimension="height",
                                line_dash="dashed",
                                line_color="red")
    waveform_plot.add_layout(photon_peak_ref_span)
    waveform_plot.add_layout(photon_peak_str_span)

    # Calibration fit plot
    fit_plot = Plot(
        title=Title(text="eTOF calibration fit"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    fit_plot.toolbar.logo = None
    fit_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

    # ---- axes
    fit_plot.add_layout(LinearAxis(axis_label="Photoelectron peak shift"),
                        place="below")
    fit_plot.add_layout(LinearAxis(axis_label="Photon energy, eV",
                                   major_label_orientation="vertical"),
                        place="left")

    # ---- grid lines
    fit_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    fit_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- circle glyphs
    fit_ref_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_circle = fit_plot.add_glyph(
        fit_ref_circle_source, Circle(x="x", y="y", line_color="blue"))
    fit_str_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_circle = fit_plot.add_glyph(fit_str_circle_source,
                                        Circle(x="x", y="y", line_color="red"))

    # ---- line glyphs
    fit_ref_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_line = fit_plot.add_glyph(fit_ref_line_source,
                                      Line(x="x", y="y", line_color="blue"))
    fit_str_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_line = fit_plot.add_glyph(fit_str_line_source,
                                      Line(x="x", y="y", line_color="red"))

    # ---- legend
    fit_plot.add_layout(
        Legend(items=[
            ("reference", [fit_ref_circle, fit_ref_line]),
            ("streaked", [fit_str_circle, fit_str_line]),
        ]))
    fit_plot.legend.click_policy = "hide"

    # Calibration results datatables
    def datatable_ref_source_callback(_attr, _old_value, new_value):
        for en, ps, use in zip(new_value["energy"], new_value["peak_pos_ref"],
                               new_value["use_in_fit"]):
            palm.etofs["0"].calib_data.loc[
                en, "calib_tpeak"] = ps if ps != "NaN" else np.nan
            palm.etofs["0"].calib_data.loc[en, "use_in_fit"] = use

        calib_res = {}
        for etof_key in palm.etofs:
            calib_res[etof_key] = palm.etofs[etof_key].fit_calibration_curve()
        update_calibration_plot(calib_res)

    datatable_ref_source = ColumnDataSource(
        dict(energy=["", "", ""],
             peak_pos_ref=["", "", ""],
             use_in_fit=[True, True, True]))
    datatable_ref_source.on_change("data", datatable_ref_source_callback)

    datatable_ref = DataTable(
        source=datatable_ref_source,
        columns=[
            TableColumn(field="energy",
                        title="Photon Energy, eV",
                        editor=IntEditor()),
            TableColumn(field="peak_pos_ref",
                        title="Reference Peak",
                        editor=IntEditor()),
            TableColumn(field="use_in_fit",
                        title=" ",
                        editor=CheckboxEditor(),
                        width=80),
        ],
        index_position=None,
        editable=True,
        height=300,
        width=250,
    )

    def datatable_str_source_callback(_attr, _old_value, new_value):
        for en, ps, use in zip(new_value["energy"], new_value["peak_pos_str"],
                               new_value["use_in_fit"]):
            palm.etofs["1"].calib_data.loc[
                en, "calib_tpeak"] = ps if ps != "NaN" else np.nan
            palm.etofs["1"].calib_data.loc[en, "use_in_fit"] = use

        calib_res = {}
        for etof_key in palm.etofs:
            calib_res[etof_key] = palm.etofs[etof_key].fit_calibration_curve()
        update_calibration_plot(calib_res)

    datatable_str_source = ColumnDataSource(
        dict(energy=["", "", ""],
             peak_pos_str=["", "", ""],
             use_in_fit=[True, True, True]))
    datatable_str_source.on_change("data", datatable_str_source_callback)

    datatable_str = DataTable(
        source=datatable_str_source,
        columns=[
            TableColumn(field="energy",
                        title="Photon Energy, eV",
                        editor=IntEditor()),
            TableColumn(field="peak_pos_str",
                        title="Streaked Peak",
                        editor=IntEditor()),
            TableColumn(field="use_in_fit",
                        title=" ",
                        editor=CheckboxEditor(),
                        width=80),
        ],
        index_position=None,
        editable=True,
        height=350,
        width=250,
    )

    # eTOF calibration folder path text input
    def path_textinput_callback(_attr, _old_value, _new_value):
        path_periodic_update()
        update_load_dropdown_menu()

    path_textinput = TextInput(title="eTOF calibration path:",
                               value=os.path.join(os.path.expanduser("~")),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # eTOF calibration eco scans dropdown
    def scans_dropdown_callback(event):
        scans_dropdown.label = event.item

    scans_dropdown = Dropdown(label="ECO scans",
                              button_type="default",
                              menu=[])
    scans_dropdown.on_click(scans_dropdown_callback)

    # ---- etof scans periodic update
    def path_periodic_update():
        new_menu = []
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(".json"):
                    new_menu.append((entry.name, entry.name))
        scans_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    path_tab = Panel(child=column(path_textinput, scans_dropdown), title="Path")

    upload_div = Div(text="Upload ECO scan (top) and all hdf5 files (bottom):")

    # ECO scan upload FileInput
    def eco_fileinput_callback(_attr, _old, new):
        with io.BytesIO(base64.b64decode(new)) as eco_scan:
            data = json.load(eco_scan)
            print(data)

    eco_fileinput = FileInput(accept=".json", disabled=True)
    eco_fileinput.on_change("value", eco_fileinput_callback)

    # HDF5 upload FileInput
    def hdf5_fileinput_callback(_attr, _old, new):
        for base64_str in new:
            with io.BytesIO(base64.b64decode(base64_str)) as hdf5_file:
                with h5py.File(hdf5_file, "r") as h5f:
                    print(h5f.keys())

    hdf5_fileinput = FileInput(accept=".hdf5,.h5",
                               multiple=True,
                               disabled=True)
    hdf5_fileinput.on_change("value", hdf5_fileinput_callback)

    upload_tab = Panel(child=column(upload_div, eco_fileinput, hdf5_fileinput),
                       title="Upload")

    # Calibrate button
    def calibrate_button_callback():
        try:
            palm.calibrate_etof_eco(eco_scan_filename=os.path.join(
                path_textinput.value, scans_dropdown.label))
        except Exception:
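            # fall back to folder-based calibration when the ECO scan file is unavailable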
            palm.calibrate_etof(folder_name=path_textinput.value)

        datatable_ref_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_ref=palm.etofs["0"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["0"].calib_data["use_in_fit"].tolist(),
        )

        datatable_str_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_str=palm.etofs["1"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["1"].calib_data["use_in_fit"].tolist(),
        )

    def update_calibration_plot(calib_res):
        etof_ref = palm.etofs["0"]
        etof_str = palm.etofs["1"]

        shift_val = 0
        etof_ref_wf_shifted = []
        etof_str_wf_shifted = []
        for wf_ref, wf_str in zip(etof_ref.calib_data["waveform"],
                                  etof_str.calib_data["waveform"]):
            shift_val -= max(wf_ref.max(), wf_str.max())
            etof_ref_wf_shifted.append(wf_ref + shift_val)
            etof_str_wf_shifted.append(wf_str + shift_val)

        waveform_ref_source.data.update(
            xs=len(etof_ref.calib_data) *
            [list(range(etof_ref.internal_time_bins))],
            ys=etof_ref_wf_shifted,
            en=etof_ref.calib_data.index.tolist(),
        )

        waveform_str_source.data.update(
            xs=len(etof_str.calib_data) *
            [list(range(etof_str.internal_time_bins))],
            ys=etof_str_wf_shifted,
            en=etof_str.calib_data.index.tolist(),
        )

        photon_peak_ref_span.location = etof_ref.calib_t0
        photon_peak_str_span.location = etof_str.calib_t0

        def plot_fit(time, calib_a, calib_b):
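            # eTOF calibration model: E(t) = (calib_a / t)**2 + calib_b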
            time_fit = np.linspace(np.nanmin(time), np.nanmax(time), 100)
            en_fit = (calib_a / time_fit)**2 + calib_b
            return time_fit, en_fit

        def update_plot(calib_results, circle, line):
            (a, c), x, y = calib_results
            x_fit, y_fit = plot_fit(x, a, c)
            circle.data.update(x=x, y=y)
            line.data.update(x=x_fit, y=y_fit)

        update_plot(calib_res["0"], fit_ref_circle_source, fit_ref_line_source)
        update_plot(calib_res["1"], fit_str_circle_source, fit_str_line_source)

        calib_const_div.text = f"""
        a_str = {etof_str.calib_a:.2f}<br>
        b_str = {etof_str.calib_b:.2f}<br>
        <br>
        a_ref = {etof_ref.calib_a:.2f}<br>
        b_ref = {etof_ref.calib_b:.2f}
        """

    calibrate_button = Button(label="Calibrate eTOF",
                              button_type="default",
                              width=250)
    calibrate_button.on_click(calibrate_button_callback)

    # Photon peak noise threshold value text input
    def phot_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        if new_value > 0:
            for etof in palm.etofs.values():
                etof.photon_peak_noise_thr = new_value
        else:
            phot_peak_noise_thr_spinner.value = old_value

    phot_peak_noise_thr_spinner = Spinner(title="Photon peak noise threshold:",
                                          value=1,
                                          step=0.1)
    phot_peak_noise_thr_spinner.on_change(
        "value", phot_peak_noise_thr_spinner_callback)

    # Electron peak noise threshold value text input
    def el_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        if new_value > 0:
            for etof in palm.etofs.values():
                etof.electron_peak_noise_thr = new_value
        else:
            el_peak_noise_thr_spinner.value = old_value

    el_peak_noise_thr_spinner = Spinner(title="Electron peak noise threshold:",
                                        value=10,
                                        step=0.1)
    el_peak_noise_thr_spinner.on_change("value",
                                        el_peak_noise_thr_spinner_callback)

    # Save calibration button
    def save_button_callback():
        palm.save_etof_calib(path=path_textinput.value)
        update_load_dropdown_menu()

    save_button = Button(label="Save", button_type="default", width=250)
    save_button.on_click(save_button_callback)

    # Load calibration button
    def load_dropdown_callback(event):
        new_value = event.item
        if new_value:
            palm.load_etof_calib(os.path.join(path_textinput.value, new_value))

            datatable_ref_source.data.update(
                energy=palm.etofs["0"].calib_data.index.tolist(),
                peak_pos_ref=palm.etofs["0"].calib_data["calib_tpeak"].tolist(
                ),
                use_in_fit=palm.etofs["0"].calib_data["use_in_fit"].tolist(),
            )

            datatable_str_source.data.update(
                energy=palm.etofs["0"].calib_data.index.tolist(),
                peak_pos_str=palm.etofs["1"].calib_data["calib_tpeak"].tolist(
                ),
                use_in_fit=palm.etofs["1"].calib_data["use_in_fit"].tolist(),
            )

    def update_load_dropdown_menu():
        new_menu = []
        calib_file_ext = ".palm_etof"
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(calib_file_ext):
                    new_menu.append(
                        (entry.name[:-len(calib_file_ext)], entry.name))
            load_dropdown.button_type = "default"
            load_dropdown.menu = sorted(new_menu, reverse=True)
        else:
            load_dropdown.button_type = "danger"
            load_dropdown.menu = new_menu

    doc.add_next_tick_callback(update_load_dropdown_menu)
    doc.add_periodic_callback(update_load_dropdown_menu, 5000)

    load_dropdown = Dropdown(label="Load", menu=[], width=250)
    load_dropdown.on_click(load_dropdown_callback)

    # eTOF fitting equation
    fit_eq_div = Div(
        text="""Fitting equation:<br><br><img src="/palm/static/5euwuy.gif">"""
    )

    # Calibration constants
    calib_const_div = Div(text=f"""
        a_str = {0}<br>
        b_str = {0}<br>
        <br>
        a_ref = {0}<br>
        b_ref = {0}
        """)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, fit_plot),
            Spacer(width=30),
            column(
                Tabs(tabs=[path_tab, upload_tab]),
                calibrate_button,
                phot_peak_noise_thr_spinner,
                el_peak_noise_thr_spinner,
                row(save_button, load_dropdown),
                row(datatable_ref, datatable_str),
                calib_const_div,
                fit_eq_div,
            ),
        ))

    return Panel(child=tab_layout, title="eTOF Calibration")
Example #15
0
class SupervisableDataset(Loggable):
    """
    ???+ note "Feature-agnostic class for a dataset open to supervision."

        Keeping a DataFrame form and a list-of-dicts ("dictl") form, with the intention that

        - the DataFrame form supports most kinds of operations;
        - the list-of-dicts form could be useful for manipulations outside the scope of pandas;
        - synchronization between the two forms should be called sparingly.
    """

    # 'scratch': intended to be directly editable by other objects, i.e. Explorers
    # labels will be stored but not used for information in hover itself
    SCRATCH_SUBSETS = tuple(["raw"])

    # non-'scratch': intended to be read-only outside of the class
    # 'public': labels will be considered as part of the classification task and will be used for built-in supervision
    PUBLIC_SUBSETS = tuple(["train", "dev"])
    # 'private': labels will be considered as part of the classification task and will NOT be used for supervision
    PRIVATE_SUBSETS = tuple(["test"])

    FEATURE_KEY = "feature"

    def __init__(
        self,
        raw_dictl,
        train_dictl=None,
        dev_dictl=None,
        test_dictl=None,
        feature_key="feature",
        label_key="label",
    ):
        """
        ???+ note "Create (1) dictl and df forms and (2) the mapping between categorical and string labels."
            | Param         | Type   | Description                          |
            | :------------ | :----- | :----------------------------------- |
            | `raw_dictl`   | `list` | list of dicts holding the **to-be-supervised** raw data |
            | `train_dictl` | `list` | list of dicts holding any **supervised** train data |
            | `dev_dictl`   | `list` | list of dicts holding any **supervised** dev data   |
            | `test_dictl`  | `list` | list of dicts holding any **supervised** test data  |
            | `feature_key` | `str`  | the key for the feature in each piece of data |
            | `label_key`   | `str`  | the key for the `**str**` label in supervised data |
        """
        self._info("Initializing...")

        def dictl_transform(dictl, labels=True):
            """
            Burner function to transform the input list of dictionaries into standard format.
            """
            # edge case when dictl is empty or None
            if not dictl:
                return []

            # transform the feature and possibly the label
            key_transform = {feature_key: self.__class__.FEATURE_KEY}
            if labels:
                key_transform[label_key] = "label"

            def burner(d):
                """
                Burner function to transform a single dict.
                """
                if labels:
                    assert label_key in d, f"Expected dict key {label_key}"

                trans_d = {
                    key_transform.get(_k, _k): _v
                    for _k, _v in d.items()
                }

                if not labels:
                    trans_d["label"] = module_config.ABSTAIN_DECODED

                return trans_d

            return [burner(_d) for _d in dictl]

        self.dictls = {
            "raw": dictl_transform(raw_dictl, labels=False),
            "train": dictl_transform(train_dictl),
            "dev": dictl_transform(dev_dictl),
            "test": dictl_transform(test_dictl),
        }

        self.synchronize_dictl_to_df()
        self.df_deduplicate()
        self.synchronize_df_to_dictl()
        self.setup_widgets()
        # self.setup_label_coding() # redundant if setup_pop_table() immediately calls this again
        self.setup_file_export()
        self.setup_pop_table(width_policy="fit", height_policy="fit")
        self.setup_sel_table(width_policy="fit", height_policy="fit")
        self._good(f"{self.__class__.__name__}: finished initialization.")

    def copy(self, use_df=True):
        """
        ???+ note "Create another instance, copying over the data entries."
            | Param    | Type   | Description                          |
            | :------- | :----- | :----------------------------------- |
            | `use_df` | `bool` | whether to use the df or dictl form  |
        """
        if use_df:
            self.synchronize_df_to_dictl()
        return self.__class__(
            raw_dictl=self.dictls["raw"],
            train_dictl=self.dictls["train"],
            dev_dictl=self.dictls["dev"],
            test_dictl=self.dictls["test"],
            feature_key=self.__class__.FEATURE_KEY,
            label_key="label",
        )

    def to_pandas(self, use_df=True):
        """
        ???+ note "Export to a pandas DataFrame."
            | Param    | Type   | Description                          |
            | :------- | :----- | :----------------------------------- |
            | `use_df` | `bool` | whether to use the df or dictl form  |
        """
        if not use_df:
            self.synchronize_dictl_to_df()
        dfs = []
        for _subset in ["raw", "train", "dev", "test"]:
            _df = self.dfs[_subset].copy()
            _df[DATASET_SUBSET_FIELD] = _subset
            dfs.append(_df)

        return pd.concat(dfs, axis=0)

    @classmethod
    def from_pandas(cls, df, **kwargs):
        """
        ???+ note "Import from a pandas DataFrame."
            | Param    | Type   | Description                          |
            | :------- | :----- | :----------------------------------- |
            | `df` | `DataFrame` | with a "SUBSET" field dividing subsets |
        """
        SUBSETS = cls.SCRATCH_SUBSETS + cls.PUBLIC_SUBSETS + cls.PRIVATE_SUBSETS

        if DATASET_SUBSET_FIELD not in df.columns:
            raise ValueError(
                f"Expecting column '{DATASET_SUBSET_FIELD}' in the DataFrame which takes values from {SUBSETS}"
            )

        dictls = {}
        for _subset in ["raw", "train", "dev", "test"]:
            _sub_df = df[df[DATASET_SUBSET_FIELD] == _subset]
            dictls[_subset] = _sub_df.to_dict(orient="records")

        return cls(
            raw_dictl=dictls["raw"],
            train_dictl=dictls["train"],
            dev_dictl=dictls["dev"],
            test_dictl=dictls["test"],
            **kwargs,
        )
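
    # A hedged round-trip sketch (`df` here is a hypothetical DataFrame that
    # carries the DATASET_SUBSET_FIELD column with values such as "raw"):
    #     dataset = SupervisableDataset.from_pandas(df)
    #     recovered = dataset.to_pandas()  # regains the subset column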

    def setup_widgets(self):
        """
        ???+ note "Create `bokeh` widgets for interactive data management."
        """
        self.update_pusher = Button(label="Push",
                                    button_type="success",
                                    height_policy="fit",
                                    width_policy="min")
        self.data_committer = Dropdown(
            label="Commit",
            button_type="warning",
            menu=[
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
            ],
            height_policy="fit",
            width_policy="min",
        )
        self.dedup_trigger = Button(
            label="Dedup",
            button_type="warning",
            height_policy="fit",
            width_policy="min",
        )
        self.selection_viewer = Button(
            label="View Selected",
            button_type="primary",
            height_policy="fit",
            width_policy="min",
        )

        def commit_base_callback():
            """
            COMMIT creates cross-duplicates between subsets.

            - PUSH shall be blocked until DEDUP is executed.
            """
            self.dedup_trigger.disabled = False
            self.update_pusher.disabled = True

        def dedup_base_callback():
            """
            DEDUP re-creates dfs with different indices than before.

            - COMMIT shall be blocked until PUSH is executed.
            """
            self.update_pusher.disabled = False
            self.data_committer.disabled = True
            self.df_deduplicate()

        def push_base_callback():
            """
            PUSH enforces df consistency with all linked explorers.

            - DEDUP could be blocked because it stays trivial until COMMIT is executed.
            """
            self.data_committer.disabled = False
            self.dedup_trigger.disabled = True

        self.update_pusher.on_click(push_base_callback)
        self.data_committer.on_click(commit_base_callback)
        self.dedup_trigger.on_click(dedup_base_callback)

        self.help_div = dataset_help_widget()
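
        # net effect: the three buttons enforce a COMMIT -> DEDUP -> PUSH
        # cycle; each click enables the next step and disables the previous
        # one, so cross-subset duplicates are removed before any push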

    def view(self):
        """
        ???+ note "Defines the layout of `bokeh` objects when visualized."
        """
        # local import to avoid naming confusion/conflicts
        from bokeh.layouts import row, column

        return column(
            self.help_div,
            row(
                self.update_pusher,
                self.data_committer,
                self.dedup_trigger,
                self.selection_viewer,
                self.file_exporter,
            ),
            self.pop_table,
            self.sel_table,
        )

    def subscribe_update_push(self, explorer, subset_mapping):
        """
        ???+ note "Enable pushing updated DataFrames to explorers that depend on them."
            | Param            | Type   | Description                            |
            | :--------------- | :----- | :------------------------------------- |
            | `explorer`       | `BokehBaseExplorer` | the explorer to register  |
            | `subset_mapping` | `dict` | `dataset` -> `explorer` subset mapping |

            Note: this is needed because of `self.dfs[key] = ...`-style assignments. If all DataFrame operations were in-place, the explorers could see the updates directly through their own `self.dfs` references.
        """
        # local import to avoid import cycles
        from hover.core.explorer.base import BokehBaseExplorer

        assert isinstance(explorer, BokehBaseExplorer)

        def callback_push():
            df_dict = {_v: self.dfs[_k] for _k, _v in subset_mapping.items()}
            explorer._setup_dfs(df_dict)
            explorer._update_sources()

        self.update_pusher.on_click(callback_push)
        self._good(
            f"Subscribed {explorer.__class__.__name__} to dataset pushes: {subset_mapping}"
        )

    def subscribe_data_commit(self, explorer, subset_mapping):
        """
        ???+ note "Enable committing data across subsets, specified by a selection in an explorer and a dropdown widget of the dataset."
            | Param            | Type   | Description                            |
            | :--------------- | :----- | :------------------------------------- |
            | `explorer`       | `BokehBaseExplorer` | the explorer to register  |
            | `subset_mapping` | `dict` | `dataset` -> `explorer` subset mapping |
        """
        def callback_commit(event):
            for sub_k, sub_v in subset_mapping.items():
                sub_to = event.item
                selected_idx = explorer.sources[sub_v].selected.indices
                if not selected_idx:
                    self._warn(
                        f"Attempting data commit: did not select any data points in subset {sub_v}."
                    )
                    return

                # take selected slice, ignoring ABSTAIN'ed rows
                # CAUTION: applying selected_idx from explorer.source to self.df
                #     this assumes that the source and the df have consistent entries.
                # Consider this:
                #    keep_cols = self.dfs[sub_k].columns
                #    sel_slice = explorer.dfs[sub_v].iloc[selected_idx][keep_cols]
                sel_slice = self.dfs[sub_k].iloc[selected_idx]
                valid_slice = sel_slice[
                    sel_slice["label"] != module_config.ABSTAIN_DECODED]

                # concat to the end and do some accounting
                size_before = self.dfs[sub_to].shape[0]
                self.dfs[sub_to] = pd.concat(
                    [self.dfs[sub_to], valid_slice],
                    axis=0,
                    sort=False,
                    ignore_index=True,
                )
                size_mid = self.dfs[sub_to].shape[0]
                self.dfs[sub_to].drop_duplicates(
                    subset=[self.__class__.FEATURE_KEY],
                    keep="last",
                    inplace=True)
                size_after = self.dfs[sub_to].shape[0]

                self._info(
                    f"Committed {valid_slice.shape[0]} (valid out of {sel_slice.shape[0]} selected) entries from {sub_k} to {sub_to} ({size_before} -> {size_after} with {size_mid-size_after} overwrites)."
                )
            # chain another callback
            self._callback_update_population()

        self.data_committer.on_click(callback_commit)
        self._good(
            f"Subscribed {explorer.__class__.__name__} to dataset commits: {subset_mapping}"
        )

    def subscribe_selection_view(self, explorer, subsets):
        """
        ???+ note "Enable viewing groups of data entries, specified by a selection in an explorer."
            | Param            | Type   | Description                            |
            | :--------------- | :----- | :------------------------------------- |
            | `explorer`       | `BokehBaseExplorer` | the explorer to register  |
            | `subsets`        | `list` | subset selections to consider          |
        """
        assert (isinstance(subsets, list)
                and len(subsets) > 0), "Expected a non-empty list of subsets"

        def callback_view():
            sel_slices = []
            for subset in subsets:
                selected_idx = explorer.sources[subset].selected.indices
                sub_slice = explorer.dfs[subset].iloc[selected_idx]
                sel_slices.append(sub_slice)

            selected = pd.concat(sel_slices, axis=0)

            # replace this with an actual display (and analysis)
            self._callback_update_selection(selected)

        self.selection_viewer.on_click(callback_view)
        self._good(
            f"Subscribed {explorer.__class__.__name__} to selection view: {subsets}"
        )

    def setup_label_coding(self, verbose=True, debug=False):
        """
        ???+ note "Auto-determine labels in the dataset, then create encoder/decoder in lexical order."
            Add `"ABSTAIN"` as a no-label placeholder which gets ignored categorically.

            | Param     | Type   | Description                        |
            | :-------- | :----- | :--------------------------------- |
            | `verbose` | `bool` | whether to log verbosely           |
            | `debug`   | `bool` | whether to enable label validation |
        """
        all_labels = set()
        for _key in [
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
        ]:
            _df = self.dfs[_key]
            _found_labels = set(_df["label"].tolist())
            all_labels = all_labels.union(_found_labels)

        # exclude ABSTAIN from self.classes, but include it in the encoding
        all_labels.discard(module_config.ABSTAIN_DECODED)
        self.classes = sorted(all_labels)
        self.label_encoder = {
            **{_label: _i
               for _i, _label in enumerate(self.classes)},
            module_config.ABSTAIN_DECODED: module_config.ABSTAIN_ENCODED,
        }
        self.label_decoder = {_v: _k for _k, _v in self.label_encoder.items()}

        if verbose:
            self._good(
                f"Set up label encoder/decoder with {len(self.classes)} classes."
            )
        if debug:
            self.validate_labels()

    def validate_labels(self, raise_exception=True):
        """
        ???+ note "Assert that every label is in the encoder."

            | Param             | Type   | Description                         |
            | :---------------- | :----- | :---------------------------------- |
            | `raise_exception` | `bool` | whether to raise errors when failed |
        """
        for _key in [
                *self.__class__.PUBLIC_SUBSETS, *self.__class__.PRIVATE_SUBSETS
        ]:
            assert "label" in self.dfs[_key].columns
            _mask = self.dfs[_key]["label"].apply(
                lambda x: x in self.label_encoder)
            _invalid_indices = np.where(~_mask)[0].tolist()
            if _invalid_indices:
                self._fail(f"Subset {_key} has invalid labels:")
                self._print(self.dfs[_key].loc[_invalid_indices])
                if raise_exception:
                    raise ValueError("invalid labels")

    def setup_file_export(self):
        self.file_exporter = Dropdown(
            label="Export",
            button_type="warning",
            menu=["Excel", "CSV", "JSON", "pickle"],
            height_policy="fit",
            width_policy="min",
        )

        def callback_export(event, path_root=None):
            """
            A callback on clicking the 'self.annotator_export' button.
            Saves the dataframe to a pickle.
            """
            export_format = event.item

            # auto-determine the export path root
            if path_root is None:
                timestamp = current_time("%Y%m%d%H%M%S")
                path_root = f"hover-dataset-export-{timestamp}"

            export_df = self.to_pandas(use_df=True)

            if export_format == "Excel":
                export_path = f"{path_root}.xlsx"
                export_df.to_excel(export_path, index=False)
            elif export_format == "CSV":
                export_path = f"{path_root}.csv"
                export_df.to_csv(export_path, index=False)
            elif export_format == "JSON":
                export_path = f"{path_root}.json"
                export_df.to_json(export_path, orient="records")
            elif export_format == "pickle":
                export_path = f"{path_root}.pkl"
                export_df.to_pickle(export_path)
            else:
                raise ValueError(f"Unexpected export format {export_format}")

            self._good(f"Saved DataFrame to {export_path}")

        # assign the callback, keeping its reference
        self._callback_export = callback_export
        self.file_exporter.on_click(self._callback_export)
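
        # e.g. choosing "CSV" from the Export dropdown writes a file like
        # hover-dataset-export-<timestamp>.csv to the working directory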

    def setup_pop_table(self, **kwargs):
        """
        ???+ note "Set up a bokeh `DataTable` widget for monitoring subset data populations."

            | Param      | Type   | Description                  |
            | :--------- | :----- | :--------------------------- |
            | `**kwargs` |        | forwarded to the `DataTable` |
        """
        subsets = [
            *self.__class__.SCRATCH_SUBSETS,
            *self.__class__.PUBLIC_SUBSETS,
            *self.__class__.PRIVATE_SUBSETS,
        ]
        pop_source = ColumnDataSource(dict())
        pop_columns = [
            TableColumn(field="label", title="label"),
            *[
                TableColumn(field=f"count_{_subset}", title=_subset)
                for _subset in subsets
            ],
            TableColumn(
                field="color",
                title="color",
                formatter=HTMLTemplateFormatter(template=COLOR_GLYPH_TEMPLATE),
            ),
        ]
        self.pop_table = DataTable(source=pop_source,
                                   columns=pop_columns,
                                   **kwargs)

        def update_population():
            """
            Callback function.
            """
            # make sure that the label coding is correct
            self.setup_label_coding()

            # re-compute label population
            eff_labels = [module_config.ABSTAIN_DECODED, *self.classes]
            color_dict = auto_label_color(self.classes)
            eff_colors = [color_dict[_label] for _label in eff_labels]

            pop_data = dict(color=eff_colors, label=eff_labels)
            for _subset in subsets:
                _subpop = self.dfs[_subset]["label"].value_counts()
                pop_data[f"count_{_subset}"] = [
                    _subpop.get(_label, 0) for _label in eff_labels
                ]

            # push results to bokeh data source
            pop_source.data = pop_data

            self._good(
                f"Population updater: latest population with {len(self.classes)} classes."
            )

        update_population()
        self.dedup_trigger.on_click(update_population)

        # store the callback so that it can be referenced by other methods
        self._callback_update_population = update_population

    def setup_sel_table(self, **kwargs):
        """
        ???+ note "Set up a bokeh `DataTable` widget for viewing selected data points."

            | Param      | Type   | Description                  |
            | :--------- | :----- | :--------------------------- |
            | `**kwargs` |        | forwarded to the `DataTable` |
        """
        def auto_columns(df):
            return [TableColumn(field=_col, title=_col) for _col in df.columns]

        sel_source = ColumnDataSource(dict())
        sel_columns = auto_columns(self.dfs["train"])
        self.sel_table = DataTable(source=sel_source,
                                   columns=sel_columns,
                                   **kwargs)

        def update_selection(selected_df):
            """
            Callback function.
            """
            # push results to bokeh data source
            self.sel_table.columns = auto_columns(selected_df)
            sel_source.data = selected_df.to_dict(orient="list")

            self._good(
                f"Selection table: latest selection with {selected_df.shape[0]} entries."
            )

        self._callback_update_selection = update_selection

    def df_deduplicate(self):
        """
        ???+ note "Cross-deduplicate data entries by feature between subsets."
        """
        self._info("Deduplicating...")
        # for data entry accounting
        before, after = dict(), dict()

        # deduplicating rule: entries that come LATER are of higher priority
        ordered_subsets = [
            *self.__class__.SCRATCH_SUBSETS,
            *self.__class__.PUBLIC_SUBSETS,
            *self.__class__.PRIVATE_SUBSETS,
        ]

        # keep track of which df has which columns and which rows came from which subset
        columns = dict()
        for _key in ordered_subsets:
            before[_key] = self.dfs[_key].shape[0]
            columns[_key] = self.dfs[_key].columns
            self.dfs[_key]["__subset"] = _key

        # concatenate in order and deduplicate
        overall_df = pd.concat([self.dfs[_key] for _key in ordered_subsets],
                               axis=0,
                               sort=False)
        overall_df.drop_duplicates(subset=[self.__class__.FEATURE_KEY],
                                   keep="last",
                                   inplace=True)
        overall_df.reset_index(drop=True, inplace=True)

        # cut up slices
        for _key in ordered_subsets:
            self.dfs[_key] = overall_df[
                overall_df["__subset"] == _key].reset_index(
                    drop=True, inplace=False)[columns[_key]]
            after[_key] = self.dfs[_key].shape[0]
            self._info(
                f"--subset {_key} rows: {before[_key]} -> {after[_key]}.")

    def synchronize_dictl_to_df(self):
        """
        ???+ note "Re-make dataframes from lists of dictionaries."
        """
        self.dfs = dict()
        for _key, _dictl in self.dictls.items():
            if _dictl:
                _df = pd.DataFrame(_dictl)
                assert self.__class__.FEATURE_KEY in _df.columns
                assert "label" in _df.columns
            else:
                _df = pd.DataFrame(
                    columns=[self.__class__.FEATURE_KEY, "label"])

            self.dfs[_key] = _df

    def synchronize_df_to_dictl(self):
        """
        ???+ note "Re-make lists of dictionaries from dataframes."
        """
        self.dictls = dict()
        for _key, _df in self.dfs.items():
            self.dictls[_key] = _df.to_dict(orient="records")

    def compute_2d_embedding(self, vectorizer, method, **kwargs):
        """
        ???+ note "Get embeddings in the xy-plane and return the dimensionality reducer."
            Reference: [`DimensionalityReducer`](https://github.com/phurwicz/hover/blob/main/hover/core/representation/reduction.py)

            | Param        | Type       | Description                        |
            | :----------- | :--------- | :--------------------------------- |
            | `vectorizer` | `callable` | the feature -> vector function     |
            | `method`     | `str`      | arg for `DimensionalityReducer`    |
            | `**kwargs`   |            | kwargs for `DimensionalityReducer` |
        """
        from hover.core.representation.reduction import DimensionalityReducer

        # prepare input vectors to manifold learning
        fit_subset = [
            *self.__class__.SCRATCH_SUBSETS, *self.__class__.PUBLIC_SUBSETS
        ]
        trans_subset = [*self.__class__.PRIVATE_SUBSETS]

        assert not set(fit_subset).intersection(
            set(trans_subset)), "Unexpected overlap"

        # compute vectors and keep track of where to slice the array for fitting
        feature_inp = []
        for _key in fit_subset:
            feature_inp += self.dfs[_key][self.__class__.FEATURE_KEY].tolist()
        fit_num = len(feature_inp)
        for _key in trans_subset:
            feature_inp += self.dfs[_key][self.__class__.FEATURE_KEY].tolist()
        trans_arr = np.array([vectorizer(_inp) for _inp in tqdm(feature_inp)])

        # initialize and fit manifold learning reducer using specified subarray
        self._info(
            f"Fit-transforming {method.upper()} on {fit_num} samples...")
        reducer = DimensionalityReducer(trans_arr[:fit_num])
        fit_embedding = reducer.fit_transform(method, **kwargs)

        # compute embedding of the whole dataset
        self._info(
            f"Transforming {method.upper()} on {trans_arr.shape[0]-fit_num} samples..."
        )
        trans_embedding = reducer.transform(trans_arr[fit_num:], method)

        # assign x and y coordinates to dataset
        start_idx = 0
        for _subset, _embedding in [
            (fit_subset, fit_embedding),
            (trans_subset, trans_embedding),
        ]:
            # edge case: embedding is too small
            if _embedding.shape[0] < 1:
                for _key in _subset:
                    assert (self.dfs[_key].shape[0] == 0
                            ), "Expected empty df due to empty embedding"
                continue
            for _key in _subset:
                _length = self.dfs[_key].shape[0]
                self.dfs[_key]["x"] = pd.Series(
                    _embedding[start_idx:(start_idx + _length), 0])
                self.dfs[_key]["y"] = pd.Series(
                    _embedding[start_idx:(start_idx + _length), 1])
                start_idx += _length

        return reducer
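
    # A hedged usage sketch (assumes `dataset` is an instance, `vectorizer`
    # maps a feature to a 1-D numpy vector, and umap-learn backs "umap"):
    #     reducer = dataset.compute_2d_embedding(vectorizer, "umap")
    #     dataset.dfs["raw"][["x", "y"]].head()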

    def loader(self, key, *vectorizers, batch_size=64, smoothing_coeff=0.0):
        """
        ???+ note "Prepare a torch `Dataloader` for training or evaluation."
            | Param         | Type          | Description                        |
            | :------------ | :------------ | :--------------------------------- |
            | `key`         | `str`         | subset of data, e.g. `"train"`     |
            | `vectorizers` | `callable`(s) | the feature -> vector function(s)  |
            | `batch_size`  | `int`         | size per batch                     |
            | `smoothing_coeff` | `float`   | portion of probability to equally split between classes |
        """
        # lazy import: missing torch should not break the rest of the class
        from hover.utils.torch_helper import (
            VectorDataset,
            MultiVectorDataset,
            one_hot,
            label_smoothing,
        )

        # take the slice that has a meaningful label
        df = self.dfs[key][
            self.dfs[key]["label"] != module_config.ABSTAIN_DECODED]

        # edge case: valid slice is too small
        if df.shape[0] < 1:
            raise ValueError(
                f"Subset {key} has too few samples ({df.shape[0]})")
        batch_size = min(batch_size, df.shape[0])

        # prepare output vectors
        labels = df["label"].apply(lambda x: self.label_encoder[x]).tolist()
        output_vectors = one_hot(labels, num_classes=len(self.classes))
        if smoothing_coeff > 0.0:
            output_vectors = label_smoothing(output_vectors,
                                             coefficient=smoothing_coeff)

        # prepare input vectors
        assert len(vectorizers) > 0, "Expected at least one vectorizer"
        multi_flag = len(vectorizers) > 1
        features = df[self.__class__.FEATURE_KEY].tolist()

        input_vector_lists = []
        for _vec_func in vectorizers:
            self._info(f"Preparing {key} input vectors...")
            _input_vecs = [_vec_func(_f) for _f in tqdm(features)]
            input_vector_lists.append(_input_vecs)

        self._info(f"Preparing {key} data loader...")
        if multi_flag:
            assert len(
                input_vector_lists) > 1, "Expected multiple lists of vectors"
            loader = MultiVectorDataset(
                input_vector_lists,
                output_vectors).loader(batch_size=batch_size)
        else:
            assert len(
                input_vector_lists) == 1, "Expected only one list of vectors"
            input_vectors = input_vector_lists[0]
            loader = VectorDataset(
                input_vectors, output_vectors).loader(batch_size=batch_size)
        self._good(
            f"Prepared {key} loader with {len(features)} examples; {len(vectorizers)} vectors per feature, batch size {batch_size}"
        )
        return loader
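
# A hedged usage sketch for the loader above, assuming `dataset` is a
# SupervisableDataset with a labeled "train" subset and `vectorizer` maps a
# feature to a 1-D numpy vector (both names are illustrative):
#
#     train_loader = dataset.loader("train", vectorizer,
#                                   batch_size=32, smoothing_coeff=0.1)
#     for input_vectors, output_vectors in train_loader:
#         ...  # feed the batch to a torch training step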
	dict["timestamp_"+attr] = []
	dict[attr] = []

# The data sources for the main plot and the task boxes
source      = ColumnDataSource(dict)
task_source = ColumnDataSource({"tasks" : [], "time" : []})

# The data source for the plot displaying the active tasks over time
session_source = ColumnDataSource({"xss": [], "yss": [], "colors": [], "tasktype": [], "running_tasks": []})

source.data, task_source.data, session_source.data = query_events(current_run)

# a dropdown menu to select a run for which data shall be visualized
run_menu = [(run_format(k, v['numLogEntries'], v['tstart']), add_prefix(k)) for k, v in run_map.items() if v['tstart'] is not None]  # use None for separator
dropdown = Dropdown(label="run", button_type="warning", menu=run_menu)
dropdown.on_click(lambda newValue: select_run(rem_prefix(newValue)))

manualLegendBox = Div(text=legendFormat(task_types), width=PLOT_WIDTH, height=80)

# info boxes to display the number of log messages in this run and the elapsed wall clock time
numMessages = Div(text=infoBoxFormat("Messages", 0), width=200, height=100)
runID       = Div(text=infoBoxFormat("run", str(current_run)), width=400, height=100)
startTime   = Div(text=infoBoxFormat("start time [ms]", str(general_info["start_time"])), width=200, height=100)

# Set the properties of the first plot
p = figure(plot_height=PLOT_HEIGHT, plot_width=PLOT_WIDTH,
		   #tools="pan,xpan,wheel_zoom,xwheel_zoom,ywheel_zoom,xbox_zoom,reset",
		   tools="xpan,xwheel_zoom,xbox_zoom,reset,save",
		   toolbar_location="above",
		   x_axis_type="linear", y_axis_location="right", y_axis_type=None,
		   webgl=True)
Beispiel #17
0
    x, y = get_graph_df(df_sub)
    source_sub1.data = {'month': x, 'duration': y}
    #p.line(x='month',y='duration', line_width=2,color='blue',source=source_sub,legend_label='2020 data in zipcode 1')


# callback function for dropdown button 2
def draw_line_zip2(new):
    zc = int(new.item)
    df_sub = df[df['zip_code'] == zc]
    #source_sub2 = get_graph_df(df_sub)
    x, y = get_graph_df(df_sub)
    source_sub2.data = {'month': x, 'duration': y}
    #p.line(x='month',y='duration', line_width=2,color='green',source=source_sub,legend_label='2020 data in zipcode 2')


d1.on_click(draw_line_zip1)
d2.on_click(draw_line_zip2)


# input: a df; output: x and y data for the graph (sorted months, mean durations)
def get_graph_df(df):

    months = df['month'].unique()
    months_list = list(months)
    months_list.sort()
    avg_dur = []

    # iterate through months, calculate mean in duration for each
    for i in months_list:
        df_filter = df[df['month'] == i]
        dur_i = df_filter['duration'].mean()
# initialize data source
source_function1 = ColumnDataSource(data=dict(x=[], y=[]))
source_function2 = ColumnDataSource(data=dict(x=[], y=[]))
source_result = ColumnDataSource(data=dict(x=[], y=[]))
source_convolution = ColumnDataSource(data=dict(x=[], y=[]))
source_xmarker = ColumnDataSource(data=dict(x=[], y=[]))
source_overlay = ColumnDataSource(data=dict(x=[], y=[], y_neg=[], y_pos=[]))

# initialize properties
update_is_enabled = True

# initialize controls
# dropdown menu for sample functions
function_type = Dropdown(label="choose a sample function pair or enter one below",
                         menu=convolution_settings.sample_function_names)
function_type.on_click(function_pair_input_change)

# slider controlling the evaluated x value of the convolved function
x_value_input = Slider(title="x value", name='x value', value=convolution_settings.x_value_init,
                       start=convolution_settings.x_value_min, end=convolution_settings.x_value_max,
                       step=convolution_settings.x_value_step)
x_value_input.on_change('value', input_change)
# text input for the first function to be convolved
function1_input = TextInput(value=convolution_settings.function1_input_init, title="my first function:")
function1_input.on_change('value', input_change)
# text input for the second function to be convolved
function2_input = TextInput(value=convolution_settings.function2_input_init,  # assumed setting, mirroring function1_input_init
                            title="my second function:")
function2_input.on_change('value', input_change)

# initialize plot
toolset = "crosshair,pan,reset,resize,save,wheel_zoom"
plot.scatter('x', 'y', source=source_critical_pts, color='red', legend='critical pts')
plot.multi_line('x_ls', 'y_ls', source=source_critical_lines, color='red', legend='critical lines')

# initialize controls
# text input for input of the ode system [u,v] = [x',y']
u_input = TextInput(value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][0], title="u(x,y):")
v_input = TextInput(value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][1], title="v(x,y):")

# dropdown menu for selecting one of the sample functions
sample_fun_input = Dropdown(label="choose a sample function pair or enter one below",
                            menu=odesystem_settings.sample_system_names)

# Interactor for entering starting point of initial condition
interactor = my_bokeh_utils.Interactor(plot)

# initialize callback behaviour
sample_fun_input.on_click(sample_fun_change)
u_input.on_change('value', ode_change)
v_input.on_change('value', ode_change)
interactor.on_click(initial_value_change)

# calculate data
init_data()

# lists all the controls in our app associated with the default_funs panel
function_controls = widgetbox(sample_fun_input, u_input, v_input, width=400)

# refresh the quiver field and streamlines every 100 ms
curdoc().add_periodic_callback(refresh_user_view, 100)
# make layout
curdoc().add_root(column(plot, function_controls))
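
# A hedged sketch of what a `sample_fun_change`-style handler might do, using
# the older bokeh on_click convention seen above; the settings lookup is an
# assumption:
#
#     def sample_fun_change(selected):
#         # push the chosen sample system into the text inputs; their
#         # 'value' callbacks then trigger ode_change
#         u_input.value, v_input.value = odesystem_settings.sample_system_functions[selected]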
Beispiel #20
0
        bar_chart.visible = False
        bar_chart1.visible = False
        bar_chart420.visible = True
        bar_chart330.visible = False
    elif new == '3e Klass':
        bar_chart.visible = False
        bar_chart1.visible = False
        bar_chart420.visible = False
        bar_chart330.visible = True


# Create a dropdown widget for choosing the class: selectRegio
selectRegio = Dropdown(label="Maak keuze uit de klasse", menu=["All", "1e Klass", "2e Klass", "3e Klass"])

# Attach the update_bar_chart callback to menu clicks of selectRegio
selectRegio.on_click(update_bar_chart)

# Chapter 3: scatter plots and 2-D visualization
titel4 = Div(text="<h2>Hoofdstuk 3: Scatterplots en 2-D visualisaties</h2>", width=800, height=50)
text4 = Div(text="<h4>Hieronder volgt scatterplot van prijs per ticket tegenover de GDP per capita</h4>", width=800,
            height=50)
# dataframes
FareVsGDP = pd.read_csv(r'Data K\FareVsGDP.csv')
FareVsGDP1 = pd.read_csv(r'Data K\FareVsGDP1eKlass.csv')
FareVsGDP2 = pd.read_csv(r'Data K\FareVsGDP2eKlass.csv')
FareVsGDP3 = pd.read_csv(r'Data K\FareVsGDP3eKlass.csv')

# dropdown
menu = [("Class 1", "Class_1"), ("Class 2", "Class_2"), None, ("Class 3", "Class_3")]
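# a None entry renders as a separator line in the Dropdown menu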

# Create ColumnDataSource: source
def create(palm):
    energy_min = palm.energy_range.min()
    energy_max = palm.energy_range.max()
    energy_npoints = palm.energy_range.size

    current_results = (0, 0, 0, 0)

    doc = curdoc()

    # Streaked and reference waveforms plot
    waveform_plot = Plot(
        title=Title(text="eTOF waveforms"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    waveform_plot.toolbar.logo = None
    waveform_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                            ResetTool())

    # ---- axes
    waveform_plot.add_layout(LinearAxis(axis_label="Photon energy, eV"),
                             place="below")
    waveform_plot.add_layout(LinearAxis(axis_label="Intensity",
                                        major_label_orientation="vertical"),
                             place="left")

    # ---- grid lines
    waveform_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    waveform_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- line glyphs
    waveform_source = ColumnDataSource(
        dict(x_str=[], y_str=[], x_ref=[], y_ref=[]))
    waveform_ref_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_ref", y="y_ref", line_color="blue"))
    waveform_str_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_str", y="y_str", line_color="red"))

    # ---- legend
    waveform_plot.add_layout(
        Legend(items=[("reference",
                       [waveform_ref_line]), ("streaked",
                                              [waveform_str_line])]))
    waveform_plot.legend.click_policy = "hide"

    # Cross-correlation plot
    xcorr_plot = Plot(
        title=Title(text="Waveforms cross-correlation"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    xcorr_plot.toolbar.logo = None
    xcorr_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                         ResetTool())

    # ---- axes
    xcorr_plot.add_layout(LinearAxis(axis_label="Energy shift, eV"),
                          place="below")
    xcorr_plot.add_layout(LinearAxis(axis_label="Cross-correlation",
                                     major_label_orientation="vertical"),
                          place="left")

    # ---- grid lines
    xcorr_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    xcorr_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- line glyphs
    xcorr_source = ColumnDataSource(dict(lags=[], xcorr1=[], xcorr2=[]))
    xcorr_plot.add_glyph(
        xcorr_source,
        Line(x="lags", y="xcorr1", line_color="purple", line_dash="dashed"))
    xcorr_plot.add_glyph(xcorr_source,
                         Line(x="lags", y="xcorr2", line_color="purple"))

    # ---- vertical span
    xcorr_center_span = Span(location=0, dimension="height")
    xcorr_plot.add_layout(xcorr_center_span)

    # Delays plot
    pulse_delay_plot = Plot(
        title=Title(text="Pulse delays"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    pulse_delay_plot.toolbar.logo = None
    pulse_delay_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                               ResetTool())

    # ---- axes
    pulse_delay_plot.add_layout(LinearAxis(axis_label="Pulse number"),
                                place="below")
    pulse_delay_plot.add_layout(
        LinearAxis(axis_label="Pulse delay (uncalib), eV",
                   major_label_orientation="vertical"),
        place="left",
    )

    # ---- grid lines
    pulse_delay_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    pulse_delay_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- line glyphs
    pulse_delay_source = ColumnDataSource(dict(pulse=[], delay=[]))
    pulse_delay_plot.add_glyph(
        pulse_delay_source, Line(x="pulse", y="delay", line_color="steelblue"))

    # ---- vertical span
    pulse_delay_plot_span = Span(location=0, dimension="height")
    pulse_delay_plot.add_layout(pulse_delay_plot_span)

    # Pulse lengths plot
    pulse_length_plot = Plot(
        title=Title(text="Pulse lengths"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    pulse_length_plot.toolbar.logo = None
    pulse_length_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                                ResetTool())

    # ---- axes
    pulse_length_plot.add_layout(LinearAxis(axis_label="Pulse number"),
                                 place="below")
    pulse_length_plot.add_layout(
        LinearAxis(axis_label="Pulse length (uncalib), eV",
                   major_label_orientation="vertical"),
        place="left",
    )

    # ---- grid lines
    pulse_length_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    pulse_length_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- line glyphs
    pulse_length_source = ColumnDataSource(dict(x=[], y=[]))
    pulse_length_plot.add_glyph(pulse_length_source,
                                Line(x="x", y="y", line_color="steelblue"))

    # ---- vertical span
    pulse_length_plot_span = Span(location=0, dimension="height")
    pulse_length_plot.add_layout(pulse_length_plot_span)

    # Folder path text input
    def path_textinput_callback(_attr, _old_value, new_value):
        save_textinput.value = new_value
        path_periodic_update()

    path_textinput = TextInput(title="Folder Path:",
                               value=os.path.expanduser("~"),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # Saved runs dropdown menu
    def h5_update(pulse, delays, debug_data):
        prep_data, lags, corr_res_uncut, corr_results = debug_data

        waveform_source.data.update(
            x_str=palm.energy_range,
            y_str=prep_data["1"][pulse, :],
            x_ref=palm.energy_range,
            y_ref=prep_data["0"][pulse, :],
        )

        xcorr_source.data.update(lags=lags,
                                 xcorr1=corr_res_uncut[pulse, :],
                                 xcorr2=corr_results[pulse, :])

        xcorr_center_span.location = delays[pulse]
        pulse_delay_plot_span.location = pulse
        pulse_length_plot_span.location = pulse

    # this placeholder function is reassigned inside 'update_saved_runs'
    h5_update_fun = lambda pulse: None

    def update_saved_runs(saved_run):
        new_value = saved_run
        if new_value != "Saved Runs":
            nonlocal h5_update_fun, current_results
            saved_runs_dropdown.label = new_value
            filepath = os.path.join(path_textinput.value, new_value)
            tags, delays, lengths, debug_data = palm.process_hdf5_file(
                filepath, debug=True)
            current_results = (new_value, tags, delays, lengths)

            if autosave_checkbox.active:
                save_button_callback()

            pulse_delay_source.data.update(pulse=np.arange(len(delays)),
                                           delay=delays)
            pulse_length_source.data.update(x=np.arange(len(lengths)),
                                            y=lengths)
            h5_update_fun = partial(h5_update,
                                    delays=delays,
                                    debug_data=debug_data)

            pulse_slider.end = len(delays) - 1
            pulse_slider.value = 0
            h5_update_fun(0)

    def saved_runs_dropdown_callback(event):
        update_saved_runs(event.item)

    saved_runs_dropdown = Dropdown(label="Saved Runs",
                                   button_type="primary",
                                   menu=[])
    saved_runs_dropdown.on_click(saved_runs_dropdown_callback)

    # ---- saved run periodic update
    def path_periodic_update():
        new_menu = []
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith((".hdf5", ".h5")):
                    new_menu.append((entry.name, entry.name))
        saved_runs_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Pulse number slider
    def pulse_slider_callback(_attr, _old_value, new_value):
        h5_update_fun(pulse=new_value)

    pulse_slider = Slider(
        start=0,
        end=99999,
        value_throttled=0,
        step=1,
        title="Pulse ID",
    )
    pulse_slider.on_change("value_throttled", pulse_slider_callback)

    # Energy maximal range value text input
    def energy_max_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_max
        if new_value > energy_min:
            energy_max = new_value
            palm.energy_range = np.linspace(energy_min, energy_max,
                                            energy_npoints)
            update_saved_runs(saved_runs_dropdown.label)
        else:
            energy_max_spinner.value = old_value

    energy_max_spinner = Spinner(title="Maximal Energy, eV:",
                                 value=energy_max,
                                 step=0.1)
    energy_max_spinner.on_change("value", energy_max_spinner_callback)

    # Energy minimal range value text input
    def energy_min_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_min
        if new_value < energy_max:
            energy_min = new_value
            palm.energy_range = np.linspace(energy_min, energy_max,
                                            energy_npoints)
            update_saved_runs(saved_runs_dropdown.label)
        else:
            energy_min_spinner.value = old_value

    energy_min_spinner = Spinner(title="Minimal Energy, eV:",
                                 value=energy_min,
                                 step=0.1)
    energy_min_spinner.on_change("value", energy_min_spinner_callback)

    # Energy number of interpolation points text input
    def energy_npoints_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_npoints
        if new_value > 1:
            energy_npoints = new_value
            palm.energy_range = np.linspace(energy_min, energy_max,
                                            energy_npoints)
            update_saved_runs(saved_runs_dropdown.label)
        else:
            energy_npoints_spinner.value = old_value

    energy_npoints_spinner = Spinner(title="Number of interpolation points:",
                                     value=energy_npoints)
    energy_npoints_spinner.on_change("value", energy_npoints_spinner_callback)

    # Save location
    save_textinput = TextInput(title="Save Folder Path:",
                               value=os.path.expanduser("~"))

    # Autosave checkbox
    autosave_checkbox = CheckboxButtonGroup(labels=["Auto Save"],
                                            active=[],
                                            width=250)

    # Save button
    def save_button_callback():
        if current_results[0]:
            filename, tags, delays, lengths = current_results
            save_filename = os.path.splitext(filename)[0] + ".csv"
            df = pd.DataFrame({
                "pulse_id": tags,
                "pulse_delay": delays,
                "pulse_length": lengths
            })
            df.to_csv(os.path.join(save_textinput.value, save_filename),
                      index=False)

    save_button = Button(label="Save Results",
                         button_type="default",
                         width=250)
    save_button.on_click(save_button_callback)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, xcorr_plot),
            Spacer(width=30),
            column(
                path_textinput,
                saved_runs_dropdown,
                pulse_slider,
                Spacer(height=30),
                energy_min_spinner,
                energy_max_spinner,
                energy_npoints_spinner,
                Spacer(height=30),
                save_textinput,
                autosave_checkbox,
                save_button,
            ),
        ),
        row(pulse_delay_plot, Spacer(width=10), pulse_length_plot),
    )

    return Panel(child=tab_layout, title="HDF5 File")
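
# A minimal sketch of mounting this tab in a bokeh server document, assuming
# a palm-like object providing `energy_range` and `process_hdf5_file`:
#
#     from bokeh.io import curdoc
#     from bokeh.models import Tabs
#
#     curdoc().add_root(Tabs(tabs=[create(palm)]))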