Exemplo n.º 1
0
def main(options, args):
    """Build a Bokeh document hosting a ginga image viewer.

    The page stacks the viewer figure, a "File:" text box whose value is
    loaded as an image, and a zoom slider.  When *args* is non-empty its
    first element is loaded straight away.

    :param options: command-line options (not read by this function)
    :param args: positional arguments; ``args[0]`` is an optional image path
    """
    # NOTE(review): the log file location is hard-coded and *options* is ignored.
    logger = log.get_logger("ginga", level=20, log_file="/tmp/ginga.log")

    # Only the box-select tool is enabled on the figure.
    tools = "box_select"
    fig = figure(
        x_range=[0, 600], y_range=[0, 600],
        plot_width=600, plot_height=600,
        tools=tools,
    )

    viewer = ib.CanvasView(logger)
    viewer.set_figure(fig)
    viewer.get_bindings().enable_all(True)

    def load_file(path):
        """Read the image at *path* and display it in the viewer."""
        viewer.set_image(load_data(path, logger=logger))

    def on_path_change(attr, old, new):
        load_file(new)

    def on_zoom_change(attr, old, new):
        # Non-negative slider positions are shifted up by two before being
        # handed to ginga as a zoom level; negative ones are used unchanged.
        level = new + 2 if new >= 0 else new
        viewer.zoom_to(int(level))
        scale_text = "%f" % viewer.get_scale()
        logger.info(scale_text)
        viewer.onscreen_message(scale_text, delay=0.3)

    path_w = TextInput(value="", title="File:")
    path_w.on_change('value', on_path_change)

    slide = Slider(start=-20, end=20, step=1, value=1)
    slide.on_change('value', on_zoom_change)

    curdoc().add_root(column(fig, path_w, slide))

    if args:
        load_file(args[0])
Exemplo n.º 2
0
def modify_doc(doc):
    """Populate *doc* with a circle scatter plot and two controls.

    A ``Select`` recolours the circles, and a ``TextInput`` sets how many
    points ``get_data`` generates for the plot's data source.

    :param doc: the Bokeh ``Document`` that receives the layout
    """
    source = ColumnDataSource(data=get_data(200))

    p = figure(toolbar_location=None)
    r = p.circle(x='x', y='y', radius='r', source=source,
                 color="navy", alpha=0.6, line_color="white")

    select = Select(title="Color", value="navy", options=COLORS)
    # Named ``points_input`` (not ``input``) so the builtin is not shadowed.
    points_input = TextInput(title="Number of points", value="200")

    def update_color(attrname, old, new):
        r.glyph.fill_color = select.value

    select.on_change('value', update_color)

    def update_points(attrname, old, new):
        # Text that is not an integer leaves the plot unchanged instead of
        # raising inside the server callback.
        try:
            n_points = int(points_input.value)
        except ValueError:
            return
        source.data = get_data(n_points)

    points_input.on_change('value', update_points)

    layout = column(row(select, points_input, width=400), row(p))

    doc.add_root(layout)
Exemplo n.º 3
0
def main(options, args):
    """Create the ginga-in-Bokeh viewer page and attach it to ``curdoc()``.

    Widgets: the viewer figure (box-select only), a file-path text input
    that loads the typed path, and a slider that drives the zoom level.
    If any positional *args* are given, ``args[0]`` is loaded at start-up.

    :param options: parsed command-line options (unused here)
    :param args: positional command-line arguments
    """
    # NOTE(review): log destination is fixed; *options* is not consulted.
    logger = log.get_logger("ginga", level=20, log_file="/tmp/ginga.log")

    fig = figure(tools="box_select",
                 x_range=[0, 600], y_range=[0, 600],
                 plot_height=600, plot_width=600)

    viewer = ib.CanvasView(logger)
    viewer.set_figure(fig)
    bindings = viewer.get_bindings()
    bindings.enable_all(True)

    def load_file(path):
        """Load *path* through ``load_data`` and show it."""
        loaded = load_data(path, logger=logger)
        viewer.set_image(loaded)

    def load_file_cb(attr_name, old_val, new_val):
        load_file(new_val)

    def zoom_ctl_cb(attr_name, old_val, new_val):
        # Slider values >= 0 are offset by +2 before zooming.
        zoom_level = int(new_val + 2) if new_val >= 0 else int(new_val)
        viewer.zoom_to(zoom_level)
        message = "%f" % viewer.get_scale()
        logger.info(message)
        viewer.onscreen_message(message, delay=0.3)

    file_entry = TextInput(value="", title="File:")
    file_entry.on_change('value', load_file_cb)

    zoom_slider = Slider(start=-20, end=20, step=1, value=1)
    zoom_slider.on_change('value', zoom_ctl_cb)

    curdoc().add_root(column(fig, file_entry, zoom_slider))

    if len(args) != 0:
        load_file(args[0])
Exemplo n.º 4
0
def modify_doc(doc):
    """Add a text input and a two-point plot to *doc*.

    Each edit of the text input writes ``[old, new]`` into the ``val``
    column of the plot's data source.  A ``CustomAction`` tool on the plot
    runs the JavaScript produced by ``RECORD("data", "s.data")``.
    """
    source = ColumnDataSource(dict(x=[1, 2], y=[1, 1], val=["a", "b"]))

    plot = Plot(
        plot_height=400,
        plot_width=400,
        x_range=Range1d(0, 1),
        y_range=Range1d(0, 1),
        min_border=0,
    )
    plot.add_glyph(source, Circle(x='x', y='y', size=20))

    record_js = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
    plot.add_tools(CustomAction(callback=record_js))

    text_input = TextInput(css_classes=["foo"])

    def cb(attr, old, new):
        source.data['val'] = [old, new]

    text_input.on_change('value', cb)
    doc.add_root(column(text_input, plot))
Exemplo n.º 5
0
def modify_doc(doc):
    """Build a text input plus a small plot inside *doc*.

    Editing the input stores the previous and new values as the ``val``
    column of the shared data source.  The plot carries a ``CustomAction``
    whose ``CustomJS`` code is ``RECORD("data", "s.data")``.
    """
    source = ColumnDataSource({"x": [1, 2], "y": [1, 1], "val": ["a", "b"]})
    x_range, y_range = Range1d(0, 1), Range1d(0, 1)
    plot = Plot(plot_height=400, plot_width=400,
                x_range=x_range, y_range=y_range, min_border=0)
    plot.add_glyph(source, Circle(x='x', y='y', size=20))
    plot.add_tools(
        CustomAction(callback=CustomJS(args={"s": source},
                                       code=RECORD("data", "s.data"))))

    text_input = TextInput(css_classes=["foo"])

    def record_edit(attr, old, new):
        source.data['val'] = [old, new]

    text_input.on_change('value', record_edit)
    page = column(text_input, plot)
    doc.add_root(page)
Exemplo n.º 6
0
 def create_label_text_inputs(self) -> list:
     """Build one ``TextInput`` per category, pre-filled with the category.

     Only the first input carries the title "Label".  Value changes are
     forwarded to ``self.handle_label_text_input_change`` with the
     category's index bound as the first argument.
     """
     inputs = []
     for index, category in enumerate(self.categories):
         widget = TextInput(value=category,
                            title="" if index else "Label",
                            width=100)
         handler = partial(self.handle_label_text_input_change, index)
         widget.on_change("value", handler)
         inputs.append(widget)
     return inputs
Exemplo n.º 7
0
 def create_label_inputs(self) -> column:
     """Return a column with a "Labels" header and one input per legend item.

     Each input starts with ``self.read_item_label(i)``.  Inputs whose
     label starts with ``$`` are created disabled.  Edits are routed to
     ``self.handle_label_item_change`` with the item index bound first.
     """
     widgets = []
     for index, _item in enumerate(self.legend.items):
         text = self.read_item_label(index)
         widget = TextInput(value=text, disabled=text.startswith("$"), width=210)
         widget.on_change("value", partial(self.handle_label_item_change, index))
         widgets.append(widget)
     header = Div(text="Labels")
     return column(header, *widgets)
Exemplo n.º 8
0
def create_panel(title, params, size=225):
    """Return a column layout of ``TextInput`` widgets for editing parameters.

    The column starts with a centred ``---title---`` header.  It is followed
    by one input per key of *params*, pre-filled with ``str(r0[title][param])``.
    Every input calls ``update_figs`` when its value changes.

    :param title: section name; also the key looked up in the module-level ``r0``
    :param params: mapping whose keys are the parameter names to expose
        (its values are not read)
    :param size: width, in pixels, of each widget box
    :returns: a Bokeh ``column`` (not a ``Panel``)
    """
    header = Div(text='<center><strong>---' + title + '---</strong></center>')
    wlist = [widgetbox(header, width=size)]
    for param in params:
        text = TextInput(value=str(r0[title][param]), title=param)
        text.on_change('value', update_figs)
        wlist.append(widgetbox(text, width=size))
    return column(wlist)
Exemplo n.º 9
0
def create_panel(title, params, size=40):
    """Return a ``Panel`` tab named *title* with one ``TextInput`` per parameter.

    Each input is pre-filled with ``str(r0[title][name])`` and triggers
    ``update_figs`` on change.  *size* is used as the height of each
    widget box.
    """
    def _make_box(name):
        field = TextInput(value=str(r0[title][name]), title=name)
        field.on_change('value', update_figs)
        return widgetbox(field, height=size)

    boxes = [_make_box(name) for name in params.keys()]
    return Panel(child=column(boxes), title=title)
Exemplo n.º 10
0
class SacredRunsConfigAnalyzer(BokehComponent):
    """Bokeh component comparing the configs of a contiguous range of Sacred runs.

    Two text inputs select an inclusive run-id range.  The config shared by
    every run in that range is rendered into a ``Div``, and the differences
    (``diff_as_df``) fill a ``DataTable``.  Selecting a table row reports a
    run id through the optional ``on_run_selected`` callback.
    """
    _sacred_utils: SacredUtils
    _on_run_selected: Optional[Callable[[int], None]]

    def __init__(self,
                 sacred_config: SacredConfig,
                 on_run_selected: Optional[Callable[[int], None]],
                 min_id: int = 2258,
                 max_id: int = 2325):
        """
        :param sacred_config: configuration handed to ``SacredUtils``
        :param on_run_selected: called with a run id when a table row is
            selected; may be ``None``
        :param min_id: initial value of the "Min Id" input
        :param max_id: initial value of the "Max Id" input
        """
        self._sacred_config = sacred_config
        self._on_run_selected = on_run_selected
        self._sacred_utils = SacredUtils(self._sacred_config)
        self._min_id = min_id
        self._max_id = max_id

    def _update_by_input(self):
        """Refresh the common-config ``Div`` and the diff table from the id inputs.

        The id range is inclusive on both ends.  Any exception (e.g. a
        non-integer id) is printed, and the remaining updates are skipped.
        """
        try:
            min_id = int(self.widget_min_id.value)
            max_id = int(self.widget_max_id.value)
            diff_result = self._sacred_utils.config_diff(
                list(range(min_id, max_id + 1)))
            formatted_config = diff_result.common_as_text('<br/>')
            df = diff_result.diff_as_df()
            self.ds_common.data = df
            self.widget_dt_common.columns = [
                TableColumn(field=c, title=c) for c in df.columns
            ]
            self.widget_config_common.text = f'<pre>{formatted_config} <hr/></pre>'
        except Exception as e:  # deliberate best effort for bad user input
            print(f'Exception: {e}')

    def _on_table_click(self, p):
        """Report the first run id of the first selected table row.

        :param p: the new ``selected.indices`` list of ``ds_common``; it is
            empty when the selection is cleared, which is ignored
        """
        if not p or self._on_run_selected is None:
            return
        # Index with the first selected row only.  Indexing with the whole
        # list fails on plain-list columns and on an empty selection.
        # NOTE(review): each 'run_ids' cell is presumably the list of runs
        # sharing that row's config — confirm against diff_as_df.
        run_ids = self.ds_common.data['run_ids'][p[0]]
        self._on_run_selected(run_ids[0])

    def create_layout(self):
        """Create the widgets, populate them once and return the layout."""
        self.ds_common = ColumnDataSource({'a': [4]})
        self.ds_common.selected.on_change(
            'indices', lambda a, o, n: self._on_table_click(n))
        self.widget_min_id = TextInput(title='Min Id', value=str(self._min_id))
        self.widget_min_id.on_change('value',
                                     lambda a, o, n: self._update_by_input())
        self.widget_max_id = TextInput(title='Max Id', value=str(self._max_id))
        self.widget_max_id.on_change('value',
                                     lambda a, o, n: self._update_by_input())
        self.widget_config_common = Div(text='')
        self.widget_dt_common = DataTable(source=self.ds_common, width=1000)
        self._update_by_input()
        return column(row(self.widget_min_id, self.widget_max_id),
                      self.widget_config_common, self.widget_dt_common)
Exemplo n.º 11
0
class TextInputComponent(ComponentMixin):
    """Wrap a single Bokeh ``TextInput`` as a mediator-aware component.

    Once a mediator is attached, every change of the input's ``value`` is
    forwarded through the callback built for the ``'text-change'`` event.
    """

    def __init__(self, text_input_kwargs):
        super().__init__()
        widget = TextInput(**text_input_kwargs)
        self.text_input = widget
        self.layout = widget
        self.input_text_callback = None

    def set_mediator(self, mediator):
        super().set_mediator(mediator)
        self.input_text_callback = self.make_attr_old_new_callback('text-change')
        self.text_input.on_change('value', self.input_text_callback)
Exemplo n.º 12
0
class SacredRunsConfigAnalyzer:
    """Compare the configs of a contiguous range of Sacred runs stored in MongoDB.

    ``create_layout`` builds two text inputs for an inclusive run-id range
    and a ``Div``.  The ``Div`` lists the config entries shared by every run
    in the range, then the runs grouped by their remaining config entries.
    """

    def __init__(self, mongo_observer: MongoObserver,
                 min_id: int = 1957, max_id: int = 1970):
        """
        :param mongo_observer: observer whose ``runs`` collection is queried
        :param min_id: initial value of the "Min Id" input
        :param max_id: initial value of the "Max Id" input
        """
        self._observer = mongo_observer
        self._min_id = min_id
        self._max_id = max_id

    def analyze_runs(
        self, run_ids: List[int]
    ) -> Tuple[Dict[str, Any], Dict[Tuple[Tuple[str, Any], ...], List[int]]]:
        """Split the configs of *run_ids* into a shared part and grouped differences.

        :returns: ``(common_keys, groups)``.  ``common_keys`` holds the config
            entries shared by all runs.  ``groups`` maps each distinct
            remaining config to the ids of the runs that have it.  Dict keys
            must be hashable, so they cannot be lists; presumably
            ``group_dicts`` uses tuples of ``(key, value)`` pairs — confirm.
        """
        runs = self._observer.runs.find(
            {'_id': {'$in': run_ids}},
            {'_id': 1, 'config': 1},
        )
        items = list(runs)
        common_keys = find_common_keys(items, lambda x: x['config'])
        # 'seed' is excluded from grouping along with the shared entries.
        omitted = set(common_keys) | {'seed'}
        result = group_dicts(
            items,
            lambda x: dict_omit_keys(x['config'], omitted),
            lambda x: x['_id'])
        return common_keys, result

    @staticmethod
    def _join_dict(delimiter: str, data: Dict) -> str:
        """Render *data* as ``key: value`` pairs joined by *delimiter*."""
        return delimiter.join(f'{k}: {v}' for k, v in data.items())

    def update_by_input(self):
        """Recompute the summary ``Div`` from the current id inputs.

        Any exception (e.g. a non-integer id) is printed, and the ``Div`` is
        left as it was.
        """
        try:
            min_id = int(self.widget_min_id.value)
            max_id = int(self.widget_max_id.value)
            common_keys, diff = self.analyze_runs(
                list(range(min_id, max_id + 1)))

            formatted_config = self._join_dict('<br/>', common_keys)
            formatted_diff = '<br/>'.join(
                f'{self._join_dict(", ", tuple_to_dict(k))}: {v}'
                for k, v in diff.items())
            self.widget_config_common.text = f'<pre>{formatted_config} <hr/>{formatted_diff}</pre>'
        except Exception as e:  # deliberate best effort for bad user input
            print(f'Exception: {e}')

    def create_layout(self):
        """Create the widgets, populate them once and return the layout."""
        self.widget_min_id = TextInput(title='Min Id', value=str(self._min_id))
        self.widget_min_id.on_change('value',
                                     lambda a, o, n: self.update_by_input())
        self.widget_max_id = TextInput(title='Max Id', value=str(self._max_id))
        self.widget_max_id.on_change('value',
                                     lambda a, o, n: self.update_by_input())
        self.widget_config_common = Div(text='')
        self.update_by_input()
        return column(row(self.widget_min_id, self.widget_max_id),
                      self.widget_config_common)
Exemplo n.º 13
0
def emission_lines(fname):
    """Create toggleable, redshift-adjustable ``Span`` markers for emission lines.

    *fname* is a comma-separated file.  Its first row holds one element name
    per column, and the remaining rows hold rest wavelengths; empty cells
    (read as NaN) are skipped.  Each element gets two widgets:

    - a ``CheckboxGroup`` that shows or hides its lines;
    - a ``TextInput`` holding its redshift ``z``.

    Every line is drawn at ``rest_wavelength * (1 + z)``.

    :param fname: path to the CSV file
    :returns: ``(line_list, z_in_list, checkboxes)``, where ``line_list[i]``
        is the list of ``Span`` objects for element ``i``
    """
    emission_table = np.genfromtxt(fname,
                                   skip_header=1,
                                   delimiter=',',
                                   unpack=True)
    linename_list = np.loadtxt(fname,
                               delimiter=',',
                               unpack=True,
                               max_rows=1,
                               dtype='str')
    line_list = []   # line_list[i][j]: jth drawn Span of ith element
    rest_list = []   # rest_list[i][j]: rest wavelength of line_list[i][j]
    checkboxes = []  # checkboxes[i] for ith element
    z_in_list = []   # z_in_list[i] for ith element

    def line_update(attr, old, new):
        # TODO: make it faster by only touching the element that changed.
        for spans, rests, checkbox, z_input in zip(line_list, rest_list,
                                                   checkboxes, z_in_list):
            try:
                factor = 1 + float(z_input.value)
            except ValueError:
                # Unparsable z: keep this element's lines where they are.
                factor = None
            visible = bool(checkbox.active)
            # Rest wavelengths are stored alongside the spans, so skipped
            # NaN cells cannot shift the index pairing.
            for span, rest in zip(spans, rests):
                span.visible = visible
                if factor is not None:
                    span.location = rest * factor

    # generate widgets
    for i, wavelengths in enumerate(emission_table):
        element = linename_list[i]
        print('Setting lines for ', element, '...')
        checkbox = CheckboxGroup(labels=[element], active=[])
        checkbox.on_change('active', line_update)
        checkboxes.append(checkbox)
        z_input = TextInput(value='0', sizing_mode='scale_width')
        z_input.on_change('value', line_update)
        z_in_list.append(z_input)
        spans, rests = [], []
        for wavelength in wavelengths:
            if np.isnan(wavelength):
                continue
            print('\t* lambda = ', wavelength)
            spans.append(Span(location=wavelength,
                              dimension='height',
                              line_color='orange',
                              line_width=1))  # TODO: line color
            rests.append(wavelength)
        line_list.append(spans)
        rest_list.append(rests)
    line_update('', '', '')
    return line_list, z_in_list, checkboxes
Exemplo n.º 14
0
class PVTextBox:
    """A titled ``TextInput`` whose edits are written to a PV with ``caput``.

    The title is the last ``:``-separated part of *pvname*, with underscores
    replaced by spaces, followed by the unit in parentheses.
    """

    def __init__(self, pvname, pvdef):
        self.pvname = pvname
        label = pvname.split(':')[-1].replace('_', ' ')
        heading = label + ' (' + pvdef['unit'] + ')'
        initial = str(pvdef['value'])
        self.unit = pvdef['unit']

        self.text_input = TextInput(value=initial, title=heading)
        self.text_input.on_change("value", self.set_pv)

    def set_pv(self, attr, old, new):
        """Push the new text value to the PV."""
        caput(self.pvname, new)
Exemplo n.º 15
0
class WidthWidget:
    """Text box that edits the ``precision`` of the current instrument's config.

    NOTE(review): the original one-line docstring said this "sets the width
    of the peaks"; the code only reads and writes ``ConsensusConfig.precision``.
    """
    _widget: TextInput

    def __init__(self, ctrl, mdl, *_):
        self._ctrl = ctrl
        self._mdl = mdl
        self._config = ctrl.theme.swapmodels(ConsensusConfig())
        self._theme = ctrl.display.swapmodels(WidthWidgetTheme())

    def addtodoc(self, mainview, ctrl, *_):
        """Create the text input, wire its callback and return it in a list."""
        theme = self._theme
        self._widget = TextInput(placeholder=theme.placeholder,
                                 width=theme.width,
                                 height=theme.height,
                                 **self.__data())

        def _on_cb(attr, old, new):
            # Edits made while the main view is inactive are ignored.
            if not mainview.isactive():
                return

            text = new.strip()
            try:
                val = float(new) if text else None
            except ValueError:
                # Unparsable text: restore the widget from the stored config.
                self._widget.update(**self.__data())
                return

            instr = self._mdl.instrument
            current = self._config[instr]
            if current.precision == val:
                return

            updated = copy(current)
            updated.precision = val
            with ctrl.action:
                self._ctrl.theme.update(self._config, **{instr: updated})

        self._widget.on_change("value", _on_cb)
        return [self._widget]

    def reset(self, cache):
        """Queue a widget refresh from the stored config into *cache*."""
        cache[self._widget].update(**self.__data())

    def __data(self):
        precision = self._config[self._mdl.instrument].precision
        if precision is None:
            return dict(value="")
        return dict(value=self._theme.format.format(precision))
Exemplo n.º 16
0
def make_trace_figure(trace, tids):
    """Build a task-timeline figure and a name filter for it.

    Each row of *trace* is drawn as a ``Quad``.  It spans ``start``..``end``
    horizontally and ``bottom``..``top`` vertically, and is coloured by
    ``color``.  Hovering shows ``name``, ``start`` and ``duration``.  The
    returned ``TextInput`` keeps only the rows whose ``name`` contains the
    typed pattern.

    :param trace: pandas DataFrame with the columns listed above
    :param tids: y-range factors (thread/rank labels)
    :returns: ``(text_input, plot)``
    """
    import re

    tooltips = [("Task:", "@name"), ("Start:", "@start"),
                ("Duration:", "@duration")]
    mint = min(trace['start'])
    maxt = max(trace['end'])
    plot = figure(width=1500,
                  height=800,
                  tooltips=tooltips,
                  y_range=tids,
                  x_range=(mint, maxt))
    plot.xaxis.axis_label = "Time [s.]"
    plot.yaxis.axis_label = "Thread/rank"

    source = ColumnDataSource(trace)
    plot.add_glyph(
        source,
        Quad(left='start',
             right='end',
             top='top',
             bottom='bottom',
             fill_color='color',
             line_color='color',
             fill_alpha=0.9,
             line_alpha=0.9))

    ## Create filter
    text_input = TextInput(value="", title="Filter")

    def text_input_fn(attr, old, new):
        fil = text_input.value
        names = trace['name'].str
        # na=False: rows with a missing name never match.  Otherwise the
        # boolean mask would contain NaN and the indexing would fail.
        try:
            mask = names.contains(fil, na=False)
        except re.error:
            # Typed text is not a valid regex (e.g. "foo("): match it literally.
            mask = names.contains(fil, regex=False, na=False)
        new_trace = trace[mask]
        print("Filtering using {}, originally {} rows, now {} rows".format(
            fil, trace.shape[0], new_trace.shape[0]))
        source.data = new_trace
        print("Done filtering")

    text_input.on_change('value', text_input_fn)

    print("Done preparing plot...")
    return text_input, plot
Exemplo n.º 17
0
    def get_widgets(self):
        """Create the FDR text box, the gene-count slider and the Run button.

        - Editing the FDR box stores ``float(new)`` in ``self.cutoff``.  It
          then sets the slider's ``end`` to
          ``self.calculate_n_genes_available()``.
        - Moving the slider stores ``int(new)`` in ``self.n_genes_to_show``.
        - The button calls ``self.callback``.

        :returns: ``[fdr_input, genes_to_plot, run_button]``
        """
        fdr_input = TextInput(value="1e-12",
                              title="FDR cut-off (from 0 to 1):")
        genes_to_plot = Slider(start=2,
                               end=self.n_genes_relevant,
                               value=self.n_genes_relevant,
                               step=1,
                               title="Number of genes to plot with given FDR")
        run_button = Button(label="Run", button_type="success")

        def on_fdr_change(attr, old, new):
            # Keep the slider's upper bound in line with the new cut-off.
            self.cutoff = float(new)
            genes_to_plot.end = self.calculate_n_genes_available()

        def on_gene_count_change(attr, old, new):
            self.n_genes_to_show = int(new)

        fdr_input.on_change("value", on_fdr_change)
        genes_to_plot.on_change("value", on_gene_count_change)
        run_button.on_click(self.callback)

        return [fdr_input, genes_to_plot, run_button]
def add_item():
    """Append a new item row to ``item_data_column_layout``.

    The row holds four widgets:

    - a price input;
    - a product-name input;
    - a one-button toggle labelled 'G';
    - a remove ('X') button.

    Removing the last remaining row adds a fresh one, so the layout is
    never empty.
    """
    price_input = TextInput(value="[Preis]", width=140)
    name_input = TextInput(value="[Produkt]", width=140)
    coupon_toggle = CheckboxButtonGroup(labels=['G'],
                                        active=[],
                                        width_policy="min")
    remove_button = Button(label='X',
                           width_policy='min',
                           button_type='warning')
    row_layout = row(price_input, name_input, coupon_toggle, remove_button)

    def remove_this_item(*args):
        item_data_column_layout.children.remove(row_layout)
        if item_data_column_layout.children:
            update_coupon_and_endsum()
        else:
            add_item()  # never leave the layout empty

    remove_button.on_click(remove_this_item)
    coupon_toggle.on_click(update_coupon_and_endsum)
    price_input.on_change('value', update_coupon_and_endsum_wrapper)
    item_data_column_layout.children.append(row_layout)
    update_coupon_and_endsum()
Exemplo n.º 19
0
def spotify_graph_handler(doc: Document) -> None:
    """Populate *doc* with playlist/artist inputs and a plot that follows them.

    Changing either input replaces the last child of the ``main_layout``
    column with the plot from ``PlotHandler.get_plot`` for that request.
    """
    plot_handler = PlotHandler()

    def make_handler(request_type):
        """Return an ``on_change`` callback that swaps in a *request_type* plot."""
        def handler(attr, old, new):
            # Look the layout up on *doc* itself instead of the ambient
            # curdoc(), so the callback is tied to the document it was built for.
            root_layout = doc.get_model_by_name('main_layout')
            root_layout.children[-1] = plot_handler.get_plot(request_type, new)
        return handler

    playlist_input = TextInput(value="", title="Playlist:")
    playlist_input.on_change("value", make_handler(RequestType.Playlist))

    artist_input = TextInput(value="", title="Artist:")
    artist_input.on_change("value", make_handler(RequestType.SingleArtist))

    doc.add_root(
        column(playlist_input,
               artist_input,
               plot_handler.plot,
               name='main_layout'))
Exemplo n.º 20
0
    def get_widgets(self):
        """Create the p-value and logFC threshold text inputs.

        Editing either input stores the parsed float in ``self.pvalue_cutoff``
        or ``self.fc_cutoff``, then calls ``self.callback()`` to refresh the plot.

        :returns: ``[pvalue_input, fc_input]``
        """
        def make_cutoff_callback(attr_name):
            """Return an ``on_change`` callback that sets ``self.<attr_name>``."""
            def _callback(attr, old, new):
                setattr(self, attr_name, float(new))
                self.callback()
            return _callback

        pvalue_input = TextInput(value="0.05",
                                 title="P-value cut-off (from 0 to 1):")
        pvalue_input.on_change("value", make_cutoff_callback("pvalue_cutoff"))

        fc_input = TextInput(value="0", title="LogFC threshold to colour")
        fc_input.on_change("value", make_cutoff_callback("fc_cutoff"))

        return [pvalue_input, fc_input]
Exemplo n.º 21
0
class SelectCustomLine(BaseWidget):
    """Produces a widget for selecting a custom line (counter).

    The user sets the line name, target plot, data collection and counter
    name, plus the counter instance:

    - locality;
    - pool;
    - either "total" or a specific worker id.

    ``properties`` reads the configuration back and ``set_properties``
    restores it.
    """
    def __init__(self,
                 doc,
                 idx,
                 plots,
                 callback=None,
                 refresh_rate=500,
                 collection=None,
                 **kwargs):
        """
        :param doc: document forwarded to ``BaseWidget``
        :param idx: line index; used in the default name "Line {idx}" and
            passed to *callback* when "Remove" is clicked
        :param plots: options of the "To plot:" select (first one selected)
        :param callback: called with *idx* on "Remove"; may be ``None``
        """
        super().__init__(doc,
                         callback=callback,
                         refresh_rate=refresh_rate,
                         collection=collection,
                         **kwargs)

        self._countername_autocomplete = AutocompleteInput(
            name=f"Autocomplete_{BaseWidget.instance_num}",
            title="Countername:",
            completions=counternames,
            width=200,
        )

        self._collection_widget = DataCollectionSelect(doc,
                                                       self._set_collection,
                                                       width=120)
        self._selected_collection = None

        self._name = f"Line {idx}"
        self._name_edit = TextInput(title="Change name:",
                                    value=self._name,
                                    width=150)
        self._name_edit.on_change("value", self._change_name)
        self._title = Div(text=f"<h3>{self._name}</h3>")

        def _on_delete():
            # Without a callback, "Remove" is a no-op.  Before, it raised
            # TypeError by calling None.
            if callback is not None:
                callback(idx)

        self._delete = Button(label="Remove", width=70, button_type="danger")
        self._delete.on_click(_on_delete)
        self._to_plot = Select(options=plots,
                               value=plots[0],
                               title="To plot:",
                               width=70)

        # Instance infos.  The *_input / *_select pairs are swapped in and
        # out of the layout by _set_collection; see the index map below.
        self._locality_input = TextInput(title="Locality #id:",
                                         value="0",
                                         width=70)
        self._locality_select = Select(options=[],
                                       title="Locality #id:",
                                       value="0",
                                       width=70)
        self._thread_id = TextInput(title="Worker #id:", width=70, value="0")
        self._pool = TextInput(title="Pool name:", width=70)
        self._pool_select = Select(options=[], title="Pool name:", width=70)
        # Index 0 = "Yes" (total), 1 = "No" (a specific worker id).
        self._is_total = RadioGroup(labels=["Yes", "No"], active=0, width=30)
        self._is_total.on_change("active", self._change_is_total)

        # children[2] row slots:
        #   0 to_plot, 1 collection, 2 countername, 3 locality,
        #   4 pool, 5 "Is total?", 6 worker id / placeholder
        self._root = column(
            row(self._title, self._name_edit),
            self._delete,
            row(
                self._to_plot,
                self._collection_widget.layout(),
                self._countername_autocomplete,
                self._locality_input,
                self._pool,
                row(Div(text="Is total?"), self._is_total),
                empty_placeholder(),
            ),
        )

    def _change_name(self, attr, old, new):
        """Rename the line and update the header."""
        self._name = new
        self._title.text = f"<h3>{new}</h3>"

    def _change_is_total(self, attr, old, new):
        """Show the worker-id input when "No" (index 1) is picked, hide it otherwise."""
        if new:
            self._root.children[2].children[6] = self._thread_id
            self._pool.value = "default"
            if "default" in self._pool_select.options:
                self._pool_select.value = "default"
        else:
            self._pool.value = ""
            if "No pool" in self._pool_select.options:
                self._pool_select.value = "No pool"
            self._root.children[2].children[6] = empty_placeholder()

    def _set_collection(self, collection):
        """Switch between free-text inputs and collection-backed selects.

        With a collection, counter names, localities and pools come from
        it.  Without one, the default ``counternames`` and the text inputs
        are used.
        """
        self._selected_collection = collection
        if collection:
            self._countername_autocomplete.completions = collection.get_counter_names()
            self._locality_select.options = collection.get_localities()
            self._pool_select.options = [
                "No pool" if not pool else pool
                for pool in collection.get_pools(self._locality_input.value)
            ]
            if "No pool" in self._pool_select.options:
                self._pool_select.value = "No pool"

            self._root.children[2].children[3] = self._locality_select
            self._root.children[2].children[4] = self._pool_select
        else:
            self._countername_autocomplete.completions = counternames
            self._root.children[2].children[3] = self._locality_input
            self._root.children[2].children[4] = self._pool

    def properties(self):
        """Returns a tuple containing all the information about the custom counter line.

        In order, returns:
            id of the plot (the integer after the first space of the
                "To plot:" value)
            collection object or None
            countername of the line
            instance (from ``format_instance(locality, pool, worker_id)``)
            name of the line
        """
        plot_id = int(self._to_plot.value.split()[1])

        # Fall back to the raw typed text when no completion was accepted.
        countername = (self._countername_autocomplete.value
                       or self._countername_autocomplete.value_input)

        pool = None
        if self._selected_collection:
            locality = self._locality_select.value
            if self._pool_select.value != "No pool":
                pool = self._pool_select.value
        else:
            locality = self._locality_input.value
            if self._pool.value:
                pool = self._pool.value

        is_total = self._is_total.active != 1
        worker_id = "total" if is_total else self._thread_id.value

        instance = format_instance(locality, pool, worker_id)

        return plot_id, self._selected_collection, countername, instance, self._name

    def set_properties(self, plot_id, collection, countername, locality_id,
                       pool, thread_id, name):
        """Sets the properties of the widget from the arguments"""
        if plot_id in self._to_plot.options:
            self._to_plot.value = plot_id

        self._set_collection(collection)
        self._countername_autocomplete.value = countername

        if locality_id in self._locality_select.options:
            self._locality_select.value = locality_id
        self._locality_input.value = locality_id

        if thread_id == "total":
            self._change_is_total(None, None, 0)
            self._is_total.active = 0
        else:
            self._thread_id.value = thread_id
            self._change_is_total(None, None, 1)
            self._is_total.active = 1

        # Set the select's value.  The old code rebound the widget
        # attribute itself to the pool string.
        if pool in self._pool_select.options:
            self._pool_select.value = pool
        self._pool.value = pool

        self._change_name(None, None, name)
        self._name_edit.value = name

    def set_plots(self, plots):
        """Replace the plot options, falling back to the first if the current one vanished."""
        self._to_plot.options = plots
        if self._to_plot.value not in plots:
            self._to_plot.value = plots[0]
Exemplo n.º 22
0
# slider controlling spatial stepsize of the solver
h_slider = Slider(title="spatial meshwidth", name='spatial meshwidth', value=pde_settings.h_init, start=pde_settings.h_min,
                  end=pde_settings.h_max, step=pde_settings.h_step)
h_slider.on_change('value', mesh_change)
# slider controlling temporal stepsize of the solver
k_slider = Slider(title="temporal meshwidth", name='temporal meshwidth', value=pde_settings.k_init, start=pde_settings.k_min,
                  end=pde_settings.k_max, step=pde_settings.k_step)
k_slider.on_change('value', mesh_change)
# radiobuttons controlling pde type
pde_type = RadioButtonGroup(labels=['Heat', 'Wave'], active=0)
pde_type.on_change('active', pde_type_change)
# radiobuttons controlling solver type (changing it re-runs mesh_change)
solver_type = RadioButtonGroup(labels=['Explicit', 'Implicit'], active=0)
solver_type.on_change('active', mesh_change)
# text input for IC
initial_condition = TextInput(value=pde_settings.IC_init, title="initial condition")
initial_condition.on_change('value', initial_condition_change)

# initialize plot
# NOTE(review): the "resize" tool is not available in newer Bokeh releases;
# drop it from this list if the app is moved to one.
toolset = "crosshair,pan,reset,resize,wheel_zoom,box_zoom"
# Generate a figure container
plot = Figure(plot_height=400,
              plot_width=400,
              tools=toolset,
              title="Time dependent PDEs",
              x_range=[pde_settings.x_min, pde_settings.x_max],
              y_range=[-1, 1]
              )

# Plot the numerical solution at time=t by the x,u values in the source property
plot.line('x', 'u', source=plot_data_num,
          line_width=.5,
Exemplo n.º 23
0
def fun_change(attrname, old, new):
    """Recompute f(x) and its n-th derivative on [-5, 5] and push them to the plot.

    The function text comes from ``f_input`` and the derivative order from
    ``derivative_input``, the widget this callback is registered on.  The
    result is written to ``line_source`` as columns ``x``, ``y`` and ``dy``.
    """
    f_str = f_input.value
    f_fun, f_sym = string_to_function_parser(f_str, ['x'])
    # Read the order from derivative_input, the widget that triggers this
    # callback.  The old code read an undefined ``der`` and used a
    # Python 2 print statement.
    order = int(derivative_input.value)
    df_sym = diff(f_sym, 'x', order)
    df_fun = sym_to_function_parser(df_sym, ['x'])
    x = np.linspace(-5, 5, 100)
    y = f_fun(x)
    dy = df_fun(x)

    line_source.data = dict(x=x, y=y, dy=dy)

def init_data():
    """Fill ``line_source`` once at start-up by running the change callback without event details."""
    fun_change(attrname=None, old=None, new=None)

# Plotting: f(x) in red and its n-th derivative in blue, both read from line_source, on a fixed
# [-5, 5] x [-5, 5] window
plot = Figure(title="function plotter",
              x_range=[-5,5],
              y_range=[-5,5])
plot.line(x='x', y='y', source=line_source, color='red', legend='f(x)')
plot.line(x='x', y='dy', source=line_source, color='blue', legend='df^n(x)')

#Callback: recompute both curves whenever the function text or the derivative order changes
f_input.on_change('value', fun_change)
derivative_input.on_change('value', fun_change)

# Compute the initial curves before the document is served
init_data()

#Layout: plot on the left, input widgets stacked on the right
curdoc().add_root(row(plot,column(f_input,derivative_input)))
Exemplo n.º 24
0
 def create_title_input(self) -> TextInput:
     """Build the legend-title text box, wired to ``handle_title_change``.

     The handler is fired once right away (old == new == current title), so it runs for the
     initial value as well as for later edits.
     """
     current_title = self.legend.title
     widget = TextInput(title="Title", value=current_title, width=210)
     widget.on_change("value", self.handle_title_change)
     widget.trigger("value", current_title, current_title)
     return widget
Exemplo n.º 25
0

def create_labels(new):
    """Return a ColumnDataSource placing each space-separated word of *new* on a horizontal line.

    Words are spread evenly from x=0.035 to x=0.96, all at y=0.5; columns are ``x``, ``y``, ``text``.
    """
    tokens = new.split(' ')
    count = len(tokens)
    return ColumnDataSource(data={
        'x': np.linspace(0.035, 0.96, num=count),
        'y': np.repeat(0.5, count),
        'text': tokens,
    })


def my_text_input_handler(attr, old, new):
    """Bokeh ``on_change`` callback: rebuild the plot data for the newly entered command.

    Prints the old and new text, then copies freshly built columns into the module-level ``source``
    and ``labels`` data sources in place.
    """
    print("Previous label: " + old)
    print("Updated label: " + new)
    fresh_source = create_source(new)
    fresh_labels = create_labels(new)
    source.data.update(fresh_source.data)
    labels.data.update(fresh_labels.data)


# Initial data for the default command "turn left"
source = create_source('turn left')
labels = create_labels('turn left')

# Text box for a new command; every edit rebuilds both data sources in place
text_input = TextInput(value="turn left", title="New SCAN Command:")
text_input.on_change("value", my_text_input_handler)

plot = create_plot(source, labels)

# Bound to a distinct name so the ``layout`` helper called here is not shadowed by its own result
app_layout = layout([[widgetbox(text_input)], [plot]])

curdoc().add_root(app_layout)
Exemplo n.º 26
0
# text inputs for the two component functions u(x,y) and v(x,y), pre-filled with the default sample pair
u_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][0],
                    title="u(x,y):")
v_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][1],
                    title="v(x,y):")

# dropdown menu for selecting one of the sample functions
sample_fun_input = Dropdown(
    label="choose a sample function pair or enter one below",
    menu=odesystem_settings.sample_system_names)

# Interactor for entering starting point of initial condition
interactor = my_bokeh_utils.Interactor(plot)

# initialize callback behaviour: picking a sample, editing u/v, or clicking a start point on the plot
sample_fun_input.on_click(sample_fun_change)
u_input.on_change('value', ode_change)
v_input.on_change('value', ode_change)
interactor.on_click(initial_value_change)

# calculate data
init_data()

# lists all the controls in our app associated with the default_funs panel
function_controls = widgetbox(sample_fun_input, u_input, v_input, width=400)

# refresh the quiver field and streamlines every 100 ms
curdoc().add_periodic_callback(refresh_user_view, 100)
# make layout
curdoc().add_root(row(function_controls, plot))
Exemplo n.º 27
0
plot.legend.location = "top_left"

# Set up widgets: plot title text box plus sliders for UMAP neighbours, UMAP min distance and KMeans cluster count
text = TextInput(title="title", value='UMAP Text clustering')

umap1 = Slider(title="NNeighbours", value=10, start=0, end=100, step=10)
umap2 = Slider(title="Min Dist.", value=0.5, start=0, end=1, step=0.1)
clusters = Slider(title="Clusters", value=1, start=1, end=100, step=1)


# Set up callbacks
def update_title(attrname, old, new):
    """Bokeh ``on_change`` callback: copy the text box's current value into the plot title."""
    current = text.value
    plot.title.text = current


# Re-title the plot whenever the text box changes
text.on_change('value', update_title)


def update_data(attrname, old, new):  # Update Function KMeans Clustering

    cc = clusters.value
    kmeans = KMeans(n_clusters=cc,
                    init='k-means++').fit(df.loc[::, ['umap_x', 'umap_y']])
    df.loc[::, 'topic_umap_km'] = kmeans.predict(df.loc[::,
                                                        ['umap_x', 'umap_y']])

    fact = list(map(str, sorted(list((df.topic_umap_km.unique())))))
    col_dict = dict(zip(fact, viridis(len(fact))))

    x = []
    for i in df.index:
Exemplo n.º 28
0
class Dashboard:
    """Explorepy dashboard class"""

    def __init__(self, explore=None, mode='signal'):
        """Create the dashboard state and all Bokeh data sources (figures are built later in ``_init_doc``).

        Args:
            explore: Explore object (presumably ``explorepy.Explore``); its ``stream_processor`` supplies
                the data stream and ``device_info``. NOTE(review): despite the ``None`` default this
                argument is required -- ``explore.stream_processor`` is read unconditionally.
            mode (str): ``'signal'`` for the ExG/orientation/FFT views or ``'impedance'`` for the
                impedance-measurement view.
        """
        logger.debug(f"Initializing dashboard in {mode} mode")
        self.explore = explore
        self.stream_processor = self.explore.stream_processor
        adc_mask = self.stream_processor.device_info['adc_mask']
        self.n_chan = adc_mask.count(1)
        self.y_unit = DEFAULT_SCALE
        # Channel k (0-based) is drawn around vertical offset k + 1 in the ExG plot
        self.offsets = np.arange(1, self.n_chan + 1)[:, np.newaxis].astype(float)
        # Names of enabled channels; the mask is read in reverse so index 0 maps to CHAN_LIST[0]
        self.chan_key_list = [CHAN_LIST[i]
                              for i, mask in enumerate(reversed(adc_mask)) if
                              mask == 1]
        self.exg_mode = 'EEG'
        self.rr_estimator = None
        self.win_length = WIN_LENGTH
        self.mode = mode
        self.exg_fs = self.stream_processor.device_info['sampling_rate']
        self._vis_time_offset = None
        # Moving-average baseline: window length in downsampled samples, and the running baseline
        self._baseline_corrector = {"MA_length": 1.5 * EXG_VIS_SRATE,
                                    "baseline": 0}

        # Init ExG data source: each channel starts at its offset, followed by a NaN
        exg_temp = np.zeros((self.n_chan, 2))
        exg_temp[:, 0] = self.offsets[:, 0]
        exg_temp[:, 1] = np.nan
        init_data = dict(zip(self.chan_key_list, exg_temp))
        self._exg_source_orig = ColumnDataSource(data=init_data)
        init_data['t'] = np.array([0., 0.])
        self._exg_source_ds = ColumnDataSource(data=init_data)  # Downsampled ExG data for visualization purposes

        # Init ECG R-peak source
        init_data = dict(zip(['r_peak', 't'], [np.array([None], dtype=np.double), np.array([None], dtype=np.double)]))
        self._r_peak_source = ColumnDataSource(data=init_data)

        # Init marker source
        init_data = dict(zip(['marker', 't'], [np.array([None], dtype=np.double), np.array([None], dtype=np.double)]))
        self._marker_source = ColumnDataSource(data=init_data)

        # Init ORN data source: one zero sample per orientation field (sized from ORN_LIST, not a magic 9)
        init_data = dict(zip(ORN_LIST, np.zeros((len(ORN_LIST), 1))))
        init_data['t'] = [0.]
        self._orn_source = ColumnDataSource(data=init_data)

        # Init table sources
        self._heart_rate_source = ColumnDataSource(data={'heart_rate': ['NA']})
        self._firmware_source = ColumnDataSource(
            data={'firmware_version': [self.stream_processor.device_info['firmware_version']]}
        )
        self._battery_source = ColumnDataSource(data={'battery': ['NA']})
        self.temperature_source = ColumnDataSource(data={'temperature': ['NA']})
        self.light_source = ColumnDataSource(data={'light': ['NA']})
        self.battery_percent_list = []
        self.server = None

        # Init fft data source
        init_data = dict(zip(self.chan_key_list, np.zeros((self.n_chan, 1))))
        init_data['f'] = np.array([0.])
        self.fft_source = ColumnDataSource(data=init_data)

        # Init impedance measurement source: one row entry per channel
        init_data = {'channel':   self.chan_key_list,
                     'impedance': ['NA'] * self.n_chan,
                     'row':       ['1'] * self.n_chan,
                     'color':     ['black'] * self.n_chan}
        self.imp_source = ColumnDataSource(data=init_data)

        # Init timer source
        self._timer_source = ColumnDataSource(data={'timer': ['00:00:00']})

    def start_server(self):
        """Create and start a Bokeh server on a free port, serving this dashboard at ``/``.

        Calls ``validate(False)`` first (imported elsewhere; presumably disables Bokeh model
        validation — confirm).
        """
        validate(False)
        logger.debug("Starting bokeh server...")
        free_port = find_free_port()
        logger.info("Opening the dashboard on port: %i", free_port)
        applications = {'/': self._init_doc}
        self.server = Server(applications, num_procs=1, port=free_port)
        self.server.start()

    def start_loop(self):
        """Start the server's io loop and show the dashboard at ``/``.

        Blocks while the io loop runs. On Ctrl+C:
          * signal mode: LSL streaming and recording are stopped, then the process exits immediately
            through ``os._exit(0)``;
          * otherwise: the original KeyboardInterrupt is re-raised to the caller (whose job, per the
            log message, is to disable impedance mode).
        """
        logger.debug("Starting bokeh io_loop...")
        self.server.io_loop.add_callback(self.server.show, "/")
        try:
            self.server.io_loop.start()
        except KeyboardInterrupt:
            if self.mode == 'signal':
                logger.info("Got Keyboard Interrupt. The program exits ...")
                self.explore.stop_lsl()
                self.explore.stop_recording()
                os._exit(0)
            else:
                logger.info("Got Keyboard Interrupt. The program exits after disabling the impedance mode ...")
                # Bare ``raise`` re-raises the caught interrupt with its original traceback instead of
                # creating a new KeyboardInterrupt chained onto it.
                raise

    def exg_callback(self, packet):
        """Update ExG data in the visualization

        The full-rate samples are streamed into ``_exg_source_orig`` (used by the FFT). In signal mode
        the samples are additionally downsampled to about EXG_VIS_SRATE, optionally baseline-corrected,
        scaled to the current y unit, shifted to their channel offsets, and streamed into
        ``_exg_source_ds`` on the next document tick.

        Args:
            packet (explorepy.packet.EEG): Received ExG packet

        """
        time_vector, exg = packet.get_data(self.exg_fs)
        if self._vis_time_offset is None:
            self._vis_time_offset = time_vector[0]
        # New arrays instead of in-place ``-=``: the arrays returned by the packet may be shared with
        # other subscribers of the same packet, and an in-place update would modify them.
        time_vector = time_vector - self._vis_time_offset
        self._exg_source_orig.stream(dict(zip(self.chan_key_list, exg)), rollover=int(self.exg_fs * self.win_length))

        if self.mode == 'signal':
            # Downsampling; the step is at least 1 so a device rate below EXG_VIS_SRATE cannot give
            # a zero slice step.
            step = max(1, int(self.exg_fs / EXG_VIS_SRATE))
            exg = exg[:, ::step]
            time_vector = time_vector[::step]

            # Baseline correction: exponential-style moving average of the per-channel mean
            corrector = self._baseline_corrector
            if self.baseline_widget.active:
                samples_avg = exg.mean(axis=1)
                if corrector["baseline"] is None:
                    corrector["baseline"] = samples_avg
                else:
                    corrector["baseline"] -= (
                            (corrector["baseline"] - samples_avg) / corrector["MA_length"] *
                            exg.shape[1])
                # ``exg`` is a strided view of the packet data here; subtract out of place.
                exg = exg - corrector["baseline"][:, np.newaxis]
            else:
                corrector["baseline"] = None

            # Update ExG unit
            exg = self.offsets + exg / self.y_unit
            new_data = dict(zip(self.chan_key_list, exg))
            new_data['t'] = time_vector
            self.doc.add_next_tick_callback(partial(self._update_exg, new_data=new_data))

    def orn_callback(self, packet):
        """Update orientation data

        Ignored unless the orientation tab (index 1 in signal mode) is active. Timestamps are
        shifted by the shared visualization time offset before streaming on the next document tick.

        Args:
            packet (explorepy.packet.Orientation): Orientation packet
        """
        if self.tabs.active != 1:
            return
        timestamp, orn_data = packet.get_data()
        if self._vis_time_offset is None:
            self._vis_time_offset = timestamp[0]
        # New array rather than in-place ``-=`` so the packet's own timestamp array is left untouched.
        timestamp = timestamp - self._vis_time_offset
        new_data = dict(zip(ORN_LIST, np.array(orn_data)[:, np.newaxis]))
        new_data['t'] = timestamp
        self.doc.add_next_tick_callback(partial(self._update_orn, new_data=new_data))

    def info_callback(self, packet):
        """Update device information in the dashboard

        Each field of the packet is routed to its table updater on the next document tick.
        Battery readings are averaged over the last BATTERY_N_MOVING_AVERAGE values, rounded down to a
        multiple of 5 and clamped to at least 1; light readings are truncated to int.

        Args:
            packet (explorepy.packet.Environment): Environment/DeviceInfo packet

        """
        passthrough = {'firmware_version': self._update_fw_version,
                       'temperature': self._update_temperature}
        for key, value in packet.get_data().items():
            if key in passthrough:
                self.doc.add_next_tick_callback(partial(passthrough[key], new_data={key: value}))
            elif key == 'battery':
                history = self.battery_percent_list
                history.append(value[0])
                if len(history) > BATTERY_N_MOVING_AVERAGE:
                    del history[0]
                level = max(1, int(np.mean(history) / 5) * 5)
                self.doc.add_next_tick_callback(partial(self._update_battery, new_data={key: [level]}))
            elif key == 'light':
                self.doc.add_next_tick_callback(partial(self._update_light, new_data={key: [int(value[0])]}))
            else:
                logger.warning("There is no field named: " + key)

    def marker_callback(self, packet):
        """Update markers

        Queues one vertical marker segment spanning all channel rows at the event time. The trailing
        NaN (from ``None``) breaks the line so consecutive markers are not joined.

        Args:
            packet (explorepy.packet.EventMarker): Event marker packet
        """
        if self.mode == "impedance":
            return
        timestamp, _ = packet.get_data()
        if self._vis_time_offset is None:
            self._vis_time_offset = timestamp[0]
        # Compute the shifted time without modifying the packet's timestamp array in place.
        marker_t = timestamp[0] - self._vis_time_offset
        # Only the two columns of _marker_source; the former third 'code' key was silently dropped by zip.
        new_data = {'marker': np.array([0.01, self.n_chan + 0.99, None], dtype=np.double),
                    't': np.array([marker_t, marker_t, None], dtype=np.double)}
        self.doc.add_next_tick_callback(partial(self._update_marker, new_data=new_data))

    def impedance_callback(self, packet):
        """Update impedances

        Each channel's impedance is mapped to a colour band and a label, then the impedance
        table is refreshed on the next document tick.

        Args:
             packet (explorepy.packet.EEG): ExG packet

        Raises:
            RuntimeError: if the dashboard is not in impedance mode.
        """
        if self.mode != "impedance":
            raise RuntimeError("Trying to compute impedances while the dashboard is not in Impedance mode!")

        # (exclusive lower bound in kOhm, colour), checked from highest to lowest; above 500 is "Open"
        bands = ((100, "red"), (50, "orange"), (10, "yellow"), (5, "green"))
        colours = []
        labels = []
        for value in packet.get_impedances():
            if value > 500:
                colours.append("black")
                labels.append("Open")
                continue
            for lower_bound, band_colour in bands:
                if value > lower_bound:
                    colours.append(band_colour)
                    labels.append(str(round(value, 0)) + " K\u03A9")
                    break
            else:
                # As the ADS is not precise in low values, anything not above 5 is shown as "<5".
                colours.append("green")
                labels.append("<5K\u03A9")

        data = {"impedance": labels,
                'channel': self.chan_key_list,
                'row': ['1'] * self.n_chan,
                'color': colours}
        self.doc.add_next_tick_callback(partial(self._update_imp, new_data=data))

    @gen.coroutine
    @without_property_validation
    def _update_exg(self, new_data):
        """Append downsampled ExG samples, keeping the last 2 * WIN_LENGTH seconds at EXG_VIS_SRATE."""
        max_points = int(2 * EXG_VIS_SRATE * WIN_LENGTH)
        self._exg_source_ds.stream(new_data, rollover=max_points)

    @gen.coroutine
    @without_property_validation
    def _update_orn(self, new_data):
        """Append orientation samples, keeping the last 2 * WIN_LENGTH seconds at ORN_SRATE."""
        max_points = int(2 * WIN_LENGTH * ORN_SRATE)
        self._orn_source.stream(new_data, rollover=max_points)

    @gen.coroutine
    @without_property_validation
    def _update_fw_version(self, new_data):
        """Replace the firmware-version table row (rollover of 1 keeps only the newest value)."""
        table_source = self._firmware_source
        table_source.stream(new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_battery(self, new_data):
        """Replace the battery table row (rollover of 1 keeps only the newest value)."""
        table_source = self._battery_source
        table_source.stream(new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_temperature(self, new_data):
        """Replace the temperature table row (rollover of 1 keeps only the newest value)."""
        table_source = self.temperature_source
        table_source.stream(new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_light(self, new_data):
        """Replace the light table row (rollover of 1 keeps only the newest value)."""
        table_source = self.light_source
        table_source.stream(new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_marker(self, new_data):
        """Append marker segments, keeping at most the 100 most recent points."""
        marker_source = self._marker_source
        marker_source.stream(new_data, rollover=100)

    @gen.coroutine
    @without_property_validation
    def _update_imp(self, new_data):
        """Stream new impedance rows, keeping one row per channel."""
        rows_kept = self.n_chan
        self.imp_source.stream(new_data, rollover=rows_kept)

    @gen.coroutine
    @without_property_validation
    def _update_fft(self):
        """ Update spectral frequency analysis plot

        Only runs while the spectral tab (index 2) is shown in EEG mode, and only once at least
        5 seconds of full-rate samples are buffered.
        """
        spectral_tab_shown = self.tabs.active == 2
        if not (spectral_tab_shown and self.exg_mode == 'EEG'):
            return

        buffered = self._exg_source_orig.data
        exg_data = np.array([buffered[chan] for chan in self.chan_key_list])
        if exg_data.shape[1] < self.exg_fs * 5:
            return

        fft_content, freq = get_fft(exg_data, self.exg_fs)
        spectrum = dict(zip(self.chan_key_list, fft_content))
        spectrum['f'] = freq
        self.fft_source.data = spectrum

    @gen.coroutine
    @without_property_validation
    def _update_heart_rate(self):
        """Detect R-peaks and update the plot and heart rate

        Only active in ECG mode and only when the first channel is enabled. Uses the last 2 seconds
        of the downsampled ExG buffer, converted back to signal units.
        """
        if self.exg_mode == 'EEG':
            self._heart_rate_source.stream({'heart_rate': ['NA']}, rollover=1)
            return
        ecg_chan = CHAN_LIST[0]
        if ecg_chan not in self.chan_key_list:
            logger.warning('Heart rate estimation works only when channel 1 is enabled.')
            return
        if self.rr_estimator is None:
            # The estimator is fed from _exg_source_ds, which holds samples downsampled to
            # EXG_VIS_SRATE (the 2 s window below is sized with it), not the device rate exg_fs.
            # NOTE(review): assumes exg_fs is a multiple of EXG_VIS_SRATE so the rates match exactly.
            self.rr_estimator = HeartRateEstimator(fs=EXG_VIS_SRATE)
            # Init R-peaks plot
            self.exg_plot.circle(x='t', y='r_peak', source=self._r_peak_source,
                                 fill_color="red", size=8)

        # ecg_chan is enabled and is CHAN_LIST[0], so it is chan_key_list[0] and uses offsets[0].
        ecg_data = (np.array(self._exg_source_ds.data[ecg_chan])[-2 * EXG_VIS_SRATE:] - self.offsets[0]) * self.y_unit
        time_vector = np.array(self._exg_source_ds.data['t'])[-2 * EXG_VIS_SRATE:]

        # Check if the peak2peak value is bigger than threshold
        if (np.ptp(ecg_data) < V_TH[0]) or (np.ptp(ecg_data) > V_TH[1]):
            logger.warning("P2P value larger or less than threshold. Cannot compute heart rate!")
            return

        peaks_time, peaks_val = self.rr_estimator.estimate(ecg_data, time_vector)
        # Back to plot coordinates (scaled and shifted to channel 1's offset)
        peaks_val = (np.array(peaks_val) / self.y_unit) + self.offsets[0]
        if peaks_time:
            data = dict(zip(['r_peak', 't'], [peaks_val, peaks_time]))
            self._r_peak_source.stream(data, rollover=50)

        # Update heart rate cell
        estimated_heart_rate = self.rr_estimator.heart_rate
        data = {'heart_rate': [estimated_heart_rate]}
        self._heart_rate_source.stream(data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _change_scale(self, attr, old, new):
        """Change y-scale of ExG plot

        Rescales the already-plotted channel traces and R-peak markers around their offsets so they
        match the newly selected unit (``10 ** -SCALE_MENU[label]``).
        """
        logger.debug(f"ExG scale has been changed from {old} to {new}")
        new_exp = SCALE_MENU[new]
        old_exp = SCALE_MENU[old]
        old_unit = 10 ** (-old_exp)
        self.y_unit = 10 ** (-new_exp)

        ds_data = self._exg_source_ds.data
        for chan, value in ds_data.items():
            if chan not in self.chan_key_list:
                continue
            offset = self.offsets[self.chan_key_list.index(chan)]
            ds_data[chan] = (value - offset) * (old_unit / self.y_unit) + offset

        first_offset = self.offsets[0]
        peaks = np.array(self._r_peak_source.data['r_peak'])
        self._r_peak_source.data['r_peak'] = (peaks - first_offset) * (old_unit / self.y_unit) + first_offset

    @gen.coroutine
    @without_property_validation
    def _change_t_range(self, attr, old, new):
        """Change time range to the window mapped from the selected menu label."""
        logger.debug(f"Time scale has been changed from {old} to {new}")
        window = TIME_RANGE_MENU[new]
        self._set_t_range(window)

    @gen.coroutine
    def _change_mode(self, attr, old, new):
        """Set EEG or ECG mode from the ``Signal`` select widget's new value."""
        selected_mode = new
        logger.debug(f"ExG mode has been changed to {selected_mode}")
        self.exg_mode = selected_mode

    def _init_doc(self, doc):
        """Populate a Bokeh document with the dashboard (used as the server's ``/`` application).

        Applies the HTML template and theme stored next to this module, builds plots and controls,
        lays them out according to ``self.mode``, registers the 2 s periodic FFT and heart-rate
        updates, and subscribes the packet callbacks to the stream processor.

        Args:
            doc: Bokeh document created by the server for a session.
        """
        self.doc = doc
        self.doc.title = "Explore Dashboard"
        module_dir = os.path.dirname(__file__)
        # Explicit encoding: the default is locale-dependent (e.g. cp1252 on Windows), which could
        # fail to decode a UTF-8 template.
        with open(os.path.join(module_dir, 'templates', 'index.html'), encoding='utf-8') as f:
            index_template = Template(f.read())
        doc.template = index_template
        self.doc.theme = Theme(os.path.join(module_dir, 'theme.yaml'))
        self._init_plots()
        m_widgetbox = self._init_controls()

        # Create tabs
        if self.mode == "signal":
            exg_tab = Panel(child=self.exg_plot, title="ExG Signal")
            orn_tab = Panel(child=column([self.acc_plot, self.gyro_plot, self.mag_plot], sizing_mode='scale_width'),
                            title="Orientation")
            fft_tab = Panel(child=self.fft_plot, title="Spectral analysis")
            self.tabs = Tabs(tabs=[exg_tab, orn_tab, fft_tab], width=400, sizing_mode='scale_width')
            self.recorder_widget = self._init_recorder()
            self.push2lsl_widget = self._init_push2lsl()
            self.set_marker_widget = self._init_set_marker()
            self.baseline_widget = CheckboxGroup(labels=['Baseline correction'], active=[0])

        elif self.mode == "impedance":
            imp_tab = Panel(child=self.imp_plot, title="Impedance")
            self.tabs = Tabs(tabs=[imp_tab], width=500, sizing_mode='scale_width')
        banner = Div(text=""" <a href="https://www.mentalab.com"><img src=
        "https://images.squarespace-cdn.com/content/5428308ae4b0701411ea8aaf/1505653866447-R24N86G5X1HFZCD7KBWS/
        Mentalab%2C+Name+copy.png?format=1500w&content-type=image%2Fpng" alt="Mentalab"  width="225" height="39">""",
                     width=1500, height=50, css_classes=["banner"], align='center', sizing_mode="stretch_width")
        heading = Div(text=""" """, height=2, sizing_mode="stretch_width")
        if self.mode == 'signal':
            layout = column([heading,
                             banner,
                             row(m_widgetbox,
                                 Spacer(width=10, height=300),
                                 self.tabs,
                                 Spacer(width=10, height=300),
                                 column(Spacer(width=170, height=50), self.baseline_widget, self.recorder_widget,
                                        self.set_marker_widget, self.push2lsl_widget),
                                 Spacer(width=50, height=300)),
                             ],
                            sizing_mode="stretch_both")

        elif self.mode == 'impedance':
            layout = column(banner,
                            Spacer(width=600, height=20),
                            row([m_widgetbox, Spacer(width=25, height=500), self.tabs])
                            )
        self.doc.add_root(layout)
        self.doc.add_periodic_callback(self._update_fft, 2000)
        self.doc.add_periodic_callback(self._update_heart_rate, 2000)
        if self.stream_processor:
            self.stream_processor.subscribe(topic=TOPICS.filtered_ExG, callback=self.exg_callback)
            self.stream_processor.subscribe(topic=TOPICS.raw_orn, callback=self.orn_callback)
            self.stream_processor.subscribe(topic=TOPICS.device_info, callback=self.info_callback)
            self.stream_processor.subscribe(topic=TOPICS.marker, callback=self.marker_callback)
            self.stream_processor.subscribe(topic=TOPICS.env, callback=self.info_callback)
            self.stream_processor.subscribe(topic=TOPICS.imp, callback=self.impedance_callback)

    def _init_plots(self):
        """Initialize all plots in the dashboard

        Creates the ExG, orientation (acc/gyro/mag), FFT and impedance figures, draws one line per
        channel/axis bound to the matching ColumnDataSource, sets the shared time window, and applies
        common axis and legend styling.
        """
        self.exg_plot = figure(y_range=(0.01, self.n_chan + 1 - 0.01), y_axis_label='Voltage', x_axis_label='Time (s)',
                               title="ExG signal",
                               plot_height=250, plot_width=500,
                               y_minor_ticks=int(10),
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")

        self.mag_plot = figure(y_axis_label='Mag [mgauss/LSB]', x_axis_label='Time (s)',
                               plot_height=100, plot_width=500,
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")
        self.acc_plot = figure(y_axis_label='Acc [mg/LSB]',
                               plot_height=75, plot_width=500,
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")
        # Stacked above mag_plot in the orientation tab, which carries the shared time axis
        self.acc_plot.xaxis.visible = False
        self.gyro_plot = figure(y_axis_label='Gyro [mdps/LSB]',
                                plot_height=75, plot_width=500,
                                tools=[ResetTool()], active_scroll=None, active_drag=None,
                                active_inspect=None, active_tap=None, sizing_mode="scale_width")
        self.gyro_plot.xaxis.visible = False

        self.fft_plot = figure(y_axis_label='Amplitude (uV)', x_axis_label='Frequency (Hz)', title="FFT",
                               x_range=(0, 70), plot_height=250, plot_width=500, y_axis_type="log",
                               tools=[BoxZoomTool(), ResetTool()], active_scroll=None, active_drag=None,
                               active_tap=None,
                               sizing_mode="scale_width")

        self.imp_plot = self._init_imp_plot()

        # Set yaxis properties: one major tick per channel offset (1, 2, ..., n_chan)
        self.exg_plot.yaxis.ticker = SingleIntervalTicker(interval=1, num_minor_ticks=0)

        # Initial plot line: one ExG trace and one FFT trace per enabled channel
        for i in range(self.n_chan):
            self.exg_plot.line(x='t', y=self.chan_key_list[i], source=self._exg_source_ds,
                               line_width=1.0, alpha=.9, line_color="#42C4F7")
            self.fft_plot.line(x='f', y=self.chan_key_list[i], source=self.fft_source,
                               legend_label=self.chan_key_list[i] + " ",
                               line_width=1.5, alpha=.9, line_color=FFT_COLORS[i])
        self.fft_plot.yaxis.axis_label_text_font_style = 'normal'
        # Dashed event-marker segments drawn over the ExG traces
        self.exg_plot.line(x='t', y='marker', source=self._marker_source,
                           line_width=1, alpha=.8, line_color='#7AB904', line_dash="4 4")

        # ORN_LIST is used as three groups of three: [0:3] acc, [3:6] gyro, [6:9] mag
        for i in range(3):
            self.acc_plot.line(x='t', y=ORN_LIST[i], source=self._orn_source, legend_label=ORN_LIST[i] + " ",
                               line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)
            self.gyro_plot.line(x='t', y=ORN_LIST[i + 3], source=self._orn_source, legend_label=ORN_LIST[i + 3] + " ",
                                line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)
            self.mag_plot.line(x='t', y=ORN_LIST[i + 6], source=self._orn_source, legend_label=ORN_LIST[i + 6] + " ",
                               line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)

        # Set x_range
        self.plot_list = [self.exg_plot, self.acc_plot, self.gyro_plot, self.mag_plot]
        self._set_t_range(WIN_LENGTH)

        # Set the formatting of yaxis ticks' labels: label each channel's offset tick with its name
        self.exg_plot.yaxis.major_label_overrides = dict(zip(range(1, self.n_chan + 1), self.chan_key_list))
        for plot in self.plot_list:
            plot.toolbar.autohide = True
            plot.yaxis.axis_label_text_font_style = 'normal'
            if len(plot.legend) != 0:
                plot.legend.location = "bottom_left"
                plot.legend.orientation = "horizontal"
                plot.legend.padding = 2

    def _init_imp_plot(self):
        """Build the impedance view: one coloured circle per channel with impedance and name drawn on it.

        Returns:
            The figure, with axes, grid and outline hidden.
        """
        imp_fig = figure(plot_width=600, plot_height=200, x_range=self.chan_key_list[0:self.n_chan],
                         y_range=['1'], toolbar_location=None, sizing_mode="scale_width")

        imp_fig.circle(x='channel', y="row", size=50, source=self.imp_source, fill_alpha=0.6, color="color",
                       line_color='color', line_width=2)

        label_style = dict(source=self.imp_source,
                           text_align="center",
                           text_color="white",
                           text_baseline="middle",
                           text_font="helvetica",
                           text_font_style="bold")

        # Shared horizontal nudge so both labels sit centred on their circle
        x_shift = dodge("channel", -0.1, range=imp_fig.x_range)

        impedance_label = imp_fig.text(x=x_shift, y=dodge('row', -.35, range=imp_fig.y_range),
                                       text="impedance", **label_style)
        impedance_label.glyph.text_font_size = "10pt"
        channel_label = imp_fig.text(x=x_shift, y=dodge('row', -.25, range=imp_fig.y_range), text="channel",
                                     **label_style)
        channel_label.glyph.text_font_size = "12pt"

        # Strip all chrome so only the circles and labels remain
        imp_fig.outline_line_color = None
        imp_fig.grid.grid_line_color = None
        imp_fig.axis.axis_line_color = None
        imp_fig.axis.major_tick_line_color = None
        imp_fig.axis.major_label_standoff = 0
        imp_fig.axis.visible = False
        return imp_fig

    def _init_controls(self):
        """Initialize all controls in the dashboard

        Creates the signal-mode / time-window / y-scale selectors and the device-info tables, and
        returns them stacked in a fixed-size widget box (the selection depends on ``self.mode``).
        """
        # EEG/ECG Radio button
        self.mode_control = widgets.Select(title="Signal", value='EEG', options=MODE_LIST, width=170, height=50)
        self.mode_control.on_change('value', self._change_mode)

        self.t_range = widgets.Select(title="Time window", value="10 s", options=list(TIME_RANGE_MENU.keys()),
                                      width=170, height=50)
        self.t_range.on_change('value', self._change_t_range)
        self.y_scale = widgets.Select(title="Y-axis Scale", value="1 mV", options=list(SCALE_MENU.keys()),
                                      width=170, height=50)
        self.y_scale.on_change('value', self._change_scale)

        def one_cell_table(source, field, title):
            """Return a 170x50 single-column read-only table showing *field* of *source*."""
            table_columns = [widgets.TableColumn(field=field, title=title)]
            return widgets.DataTable(source=source, index_position=None, sortable=False, reorderable=False,
                                     columns=table_columns, width=170, height=50)

        # Create device info tables
        self.heart_rate = one_cell_table(self._heart_rate_source, 'heart_rate', "Heart Rate (bpm)")
        self.firmware = one_cell_table(self._firmware_source, 'firmware_version', "Firmware Version")
        self.battery = one_cell_table(self._battery_source, 'battery', "Battery (%)")
        self.temperature = one_cell_table(self.temperature_source, 'temperature', "Device temperature (C)")
        self.light = one_cell_table(self.light_source, 'light', "Light (Lux)")

        if self.mode == 'signal':
            widget_list = [Spacer(width=170, height=30), self.mode_control, self.y_scale, self.t_range, self.heart_rate,
                           self.battery, self.temperature, self.firmware]
        elif self.mode == 'impedance':
            widget_list = [Spacer(width=170, height=40), self.battery, self.temperature, self.firmware]

        return widgetbox(widget_list, width=175, height=450, sizing_mode='fixed')

    def _init_recorder(self):
        """Build the recording panel: file name, file format, record toggle and elapsed-time table.

        Returns:
            A fixed-size column containing the recording widgets.
        """
        self.rec_button = Toggle(label="\u25CF  Record", button_type="default", active=False,
                                 width=170, height=35)
        self.file_name_widget = TextInput(value="test_file", title="File name:", width=170, height=50)
        self.file_type_widget = RadioGroup(labels=["EDF (BDF+)", "CSV"], active=0, width=170, height=50)

        centred_text = widgets.StringFormatter(text_align='center')
        timer_columns = [widgets.TableColumn(field='timer', title="Record time", formatter=centred_text)]
        self.timer = widgets.DataTable(source=self._timer_source,
                                       index_position=None,
                                       sortable=False,
                                       reorderable=False,
                                       header_row=False,
                                       columns=timer_columns,
                                       width=170,
                                       height=50,
                                       css_classes=["timer_widget"])

        self.rec_button.on_click(self._toggle_rec)
        panel_children = [Spacer(width=170, height=5), self.file_name_widget, self.file_type_widget,
                          self.rec_button, self.timer]
        return column(panel_children, width=170, height=200, sizing_mode='fixed')

    def _toggle_rec(self, active):
        """Start or stop recording when the record toggle changes state.

        Starting enables the event-marker controls, starts the recording and a 1 s timer callback.
        If the device is not connected, the toggle is reverted and the timer display reset.
        Stopping ends the recording, removes the timer and disables the marker controls unless
        LSL pushing is still active.

        Args:
            active (bool): New toggle state; True starts recording, False stops it.
        """
        logger.debug(f"Pressed record button -> {active}")
        if active:
            self.event_code_input.disabled = False
            self.marker_button.disabled = False
            if self.explore.is_connected:
                self.explore.record_data(file_name=self.file_name_widget.value,
                                         file_type=['edf', 'csv'][self.file_type_widget.active],
                                         do_overwrite=True)
                self.rec_button.label = u"\u25A0  Stop"
                self.rec_start_time = datetime.now()
                self.rec_timer_id = self.doc.add_periodic_callback(self._timer_callback, 1000)
            else:
                # No device: revert the toggle. NOTE(review): assigning ``active`` server-side presumably
                # re-enters this handler with active=False via the on_click hook — confirm.
                self.rec_button.active = False
                # Previously removed ``self.rec_timer_id`` unconditionally, although no timer had been
                # started on this path (AttributeError, or removing an already-removed callback).
                self._remove_rec_timer()
                self.doc.add_next_tick_callback(partial(self._update_rec_timer, new_data={'timer': '00:00:00'}))
        else:
            self.explore.stop_recording()
            self.rec_button.label = u"\u25CF  Record"
            self.doc.add_next_tick_callback(partial(self._update_rec_timer, new_data={'timer': '00:00:00'}))
            self._remove_rec_timer()
            if not self.push2lsl_button.active:
                self.event_code_input.disabled = True
                self.marker_button.disabled = True

    def _remove_rec_timer(self):
        """Remove the record-timer periodic callback if one is registered; safe to call repeatedly."""
        timer_id = getattr(self, 'rec_timer_id', None)
        if timer_id is not None:
            self.doc.remove_periodic_callback(timer_id)
            self.rec_timer_id = None

    def _timer_callback(self):
        """Push the elapsed recording time (HH:MM:SS) to the timer table.

        Scheduled once per second by ``_toggle_rec`` while recording.
        """
        # total_seconds() instead of .seconds: .seconds excludes whole days, so
        # the display would wrap back to 00:00:00 after 24 h of recording.
        elapsed = int((datetime.now() - self.rec_start_time).total_seconds())
        minutes, seconds = divmod(elapsed, 60)
        hours, minutes = divmod(minutes, 60)
        timer_text = ':'.join(str(part).zfill(2) for part in (hours, minutes, seconds))
        data = {'timer': timer_text}
        self.doc.add_next_tick_callback(partial(self._update_rec_timer, new_data=data))

    def _init_push2lsl(self):
        """Build the "Push to LSL" section: a title and a start/stop toggle.

        Clicking the toggle calls ``self._toggle_push2lsl``.
        """
        title = Div(text="Push to LSL", width=170, height=10)
        self.push2lsl_button = Toggle(label=u"\u25CF  Start", button_type="default", active=False,
                                      width=170, height=35)
        self.push2lsl_button.on_click(self._toggle_push2lsl)
        children = [Spacer(width=170, height=30), title, self.push2lsl_button]
        return column(children, width=170, height=200, sizing_mode='fixed')

    def _toggle_push2lsl(self, active):
        """Start or stop streaming to LSL when the push2lsl toggle changes.

        Args:
            active (bool): New toggle state; True requests streaming to start.
        """
        logger.debug(f"Pressed push2lsl button -> {active}")
        if not active:
            self.explore.stop_lsl()
            self.push2lsl_button.label = u"\u25CF  Start"
            # Marker controls stay usable while a recording is still running.
            if not self.rec_button.active:
                self.event_code_input.disabled = True
                self.marker_button.disabled = True
            return

        self.event_code_input.disabled = False
        self.marker_button.disabled = False
        if not self.explore.is_connected:
            # No device: revert the toggle (re-enters this handler with active=False).
            self.push2lsl_button.active = False
            return
        self.explore.push2lsl()
        self.push2lsl_button.label = u"\u25A0  Stop"

    def _init_set_marker(self):
        """Build the event-marker controls: an event-code input and a "Set" button.

        Both widgets start disabled; ``_toggle_rec`` / ``_toggle_push2lsl``
        enable them. Edits are validated by ``_check_marker_value`` and the
        button calls ``_set_marker``.
        """
        self.marker_button = Button(label=u"Set", button_type="default", width=80, height=31, disabled=True)
        self.event_code_input = TextInput(value="8", title="Event code:", width=80, disabled=True)
        self.event_code_input.on_change('value', self._check_marker_value)
        self.marker_button.on_click(self._set_marker)

        # Vertical spacer offsets the button (presumably to line it up with the
        # input box rather than its title).
        button_stack = column(Spacer(width=50, height=19), self.marker_button)
        marker_row = row([self.event_code_input, button_stack], height=50, width=170)
        return column([Spacer(width=170, height=5), marker_row],
                      width=170, height=50, sizing_mode='fixed')

    def _set_marker(self):
        """Send the integer event code from the input box to the stream processor.

        If the box does not hold an integer (e.g. ``_check_marker_value`` has
        replaced an invalid entry with its hint text), a warning is logged and no
        marker is set, instead of raising ValueError inside the Bokeh callback.
        """
        raw_code = self.event_code_input.value
        try:
            code = int(raw_code)
        except ValueError:
            logger.warning(f"Ignoring invalid event code: {raw_code!r}")
            return
        self.stream_processor.set_marker(code)

    def _check_marker_value(self, attr, old, new):
        """Validate the event code typed by the user.

        Valid codes are integers from 8 to 65535 inclusive. Anything else is
        replaced by the hint text "7<val<65535".

        Args:
            attr, old, new: Standard Bokeh ``on_change`` arguments; the value is
                read from the widget itself.
        """
        try:
            code = int(self.event_code_input.value)
            # Lower bound is 8, matching the message and the hint text (the old
            # check `code < 7` wrongly accepted 7).
            if code < 8 or code > 65535:
                raise ValueError('Value must be an integer between 8 and 65535')
        except ValueError:
            self.event_code_input.value = "7<val<65535"

    @gen.coroutine
    @without_property_validation
    def _update_rec_timer(self, new_data):
        """Replace the row shown in the record-timer table with ``new_data``.

        Args:
            new_data (dict): Column data such as ``{'timer': 'HH:MM:SS'}``.
        """
        # rollover=1 keeps only the newest row, so the table shows a single value.
        timer_source = self._timer_source
        timer_source.stream(new_data, rollover=1)

    def _set_t_range(self, t_length):
        """Change time range of ExG and orientation plots.

        Makes every plot in ``self.plot_list`` follow the end of its data with a
        window of ``t_length`` (in the plots' x units) and stores the window as
        an int in ``self.win_length``.

        Args:
            t_length: New window length.
        """
        # Loop-invariant; setting it once also updates win_length when plot_list is empty.
        self.win_length = int(t_length)
        for plot in self.plot_list:
            x_range = plot.x_range
            x_range.follow = "end"
            x_range.follow_interval = t_length
            x_range.range_padding = 0.
            x_range.min_interval = t_length
Exemplo n.º 29
0
def main_doc(doc):
    """Build the SDR GUI in ``doc``.

    Adds gain/centre-frequency widgets, a utilization bar, four live plots
    (frequency, time, waterfall, IQ), a periodic refresh callback and a theme.
    Relies on module-level ``sdr``, ``shared_buffer``, ``fft_size``,
    ``samples_in_time_plots``, ``waterfall_samples``, ``pysdr``, ``np`` and ``time``.

    Args:
        doc: Bokeh ``Document`` that receives the roots, callback and theme.
    """
    # Frequency Sink (line plot)
    fft_plot = pysdr.base_plot('Freq [MHz]',
                               'PSD [dB]',
                               'Frequency Sink',
                               disable_horizontal_zooming=True)
    f = (np.linspace(-sdr.sample_rate / 2.0, sdr.sample_rate / 2.0, fft_size) +
         sdr.center_freq) / 1e6
    # set x values but use dummy values for y (filled in by plot_update)
    fft_line = fft_plot.line(f, np.zeros(len(f)), color="aqua", line_width=1)

    # Time Sink (line plot)
    time_plot = pysdr.base_plot('Time [ms]',
                                ' ',
                                'Time Sink',
                                disable_horizontal_zooming=True)
    t = np.linspace(0.0, samples_in_time_plots / sdr.sample_rate,
                    samples_in_time_plots) * 1e3  # in ms
    # set x values but use dummy values for y (filled in by plot_update)
    timeI_line = time_plot.line(t, np.zeros(len(t)), color="aqua", line_width=1)
    timeQ_line = time_plot.line(t, np.zeros(len(t)), color="red", line_width=1)

    # Waterfall Sink ("image" plot)
    waterfall_plot = pysdr.base_plot(' ',
                                     'Time',
                                     'Waterfall',
                                     disable_all_zooming=True)
    # Bokeh tries to figure out the range automatically, but here it must be explicit
    waterfall_plot._set_x_range(0, fft_size)
    waterfall_plot._set_y_range(0, waterfall_samples)
    # the x axis is not updated when the frequency changes, so the axes are hidden for now
    waterfall_plot.axis.visible = False
    waterfall_data = waterfall_plot.image(
        image=[shared_buffer['waterfall']],  # input has to be in list form
        x=0,  # start of x
        y=0,  # start of y
        dw=fft_size,  # size of x
        dh=waterfall_samples,  # size of y
        palette="Spectral9")  # closest thing to matlab's jet

    # IQ/Constellation Sink ("circle" plot)
    iq_plot = pysdr.base_plot(' ', ' ', 'IQ Plot')
    # keep the view fixed at -1..1; zooming out with the mouse wheel stops auto-ranging
    iq_plot._set_x_range(-1.0, 1.0)
    iq_plot._set_y_range(-1.0, 1.0)
    iq_data = iq_plot.circle(
        np.zeros(samples_in_time_plots),
        np.zeros(samples_in_time_plots),
        line_alpha=0.0,  # setting line_width=0 didn't hide the outline, but this works
        fill_color="aqua",
        fill_alpha=0.5,
        size=4)  # size of circles

    # Utilization bar (standard plot defined in gui.py)
    # NOTE(review): the original comment said this sets the top at 10% instead of
    # 100%; with an argument of 1.0 that is unclear — confirm against pysdr.
    utilization_plot = pysdr.utilization_bar(1.0)
    utilization_data = utilization_plot.quad(
        top=[shared_buffer['utilization']],
        bottom=[0],
        left=[0],
        right=[1],
        color="#B3DE69")  # adds 1 rectangle

    def gain_callback(attr, old, new):
        """Apply the gain chosen in the Select widget (a string)."""
        # Parse before pausing the reader, so bad input can never leave it stopped.
        gain = float(new)
        # Stop the asynchronous read; the gain can't change while it runs.
        shared_buffer['stop-signal'] = True
        try:
            # Give the stop signal time to take effect. If you get a segfault,
            # this delay needs to be increased.
            time.sleep(0.5)
            sdr.gain = gain
        finally:
            # Always resume streaming, even if setting the gain failed.
            shared_buffer['stop-signal'] = False

    def freq_callback(attr, old, new):
        """Retune to the centre frequency typed into the TextInput (a string, in Hz)."""
        # Parse first: a non-numeric entry raises here, before the reader is
        # paused, instead of leaving 'stop-signal' stuck at True.
        center_freq = float(new)
        shared_buffer['stop-signal'] = True  # see gain_callback
        try:
            time.sleep(0.5)
            sdr.center_freq = center_freq
            new_f = np.linspace(-sdr.sample_rate / 2.0, sdr.sample_rate / 2.0,
                                fft_size) + sdr.center_freq
            fft_line.data_source.data['x'] = new_f / 1e6  # update x axis of freq sink
        finally:
            shared_buffer['stop-signal'] = False

    # gain selector
    gain_select = Select(title="Gain:",
                         value=str(sdr.gain),
                         options=[str(i / 10.0) for i in sdr.get_gains()])
    gain_select.on_change('value', gain_callback)

    # center_freq TextInput
    freq_input = TextInput(value=str(sdr.center_freq),
                           title="Center Freq [Hz]")
    freq_input.on_change('value', freq_callback)

    # add the widgets to the document; widgetbox() groups them a bit tighter than column()
    doc.add_root(row([widgetbox(gain_select, freq_input), utilization_plot]))

    # Add four plots to document, using the gridplot method of arranging them
    doc.add_root(
        gridplot([[fft_plot, time_plot], [waterfall_plot, iq_plot]],
                 sizing_mode="scale_width",
                 merge_tools=False))

    def plot_update():
        """Copy the latest samples from shared_buffer into every plot (real-time streaming mode)."""
        timeI_line.data_source.data['y'] = shared_buffer['i']  # most recent I -> time sink
        timeQ_line.data_source.data['y'] = shared_buffer['q']  # most recent Q -> time sink
        iq_data.data_source.data['x'] = shared_buffer['i']  # most recent I -> IQ plot
        iq_data.data_source.data['y'] = shared_buffer['q']  # most recent Q -> IQ plot
        fft_line.data_source.data['y'] = shared_buffer['psd']  # most recent psd -> freq sink
        # waterfall 2d array -> waterfall sink
        waterfall_data.data_source.data['image'] = [shared_buffer['waterfall']]
        # most recent utilization level (only the top of the rectangle changes)
        utilization_data.data_source.data['top'] = [shared_buffer['utilization']]

    # Add a periodic callback to be run every 150 milliseconds
    doc.add_periodic_callback(plot_update, 150)

    # pull out a theme from themes.py
    doc.theme = pysdr.black_and_white
Exemplo n.º 30
0
def update():
    """Re-parse the expression input and redraw the function and its Taylor polynomial.

    If the text cannot be parsed, the error is shown in ``errbox`` and the plot
    keeps showing the last valid expression.
    """
    try:
        expr = sy.sympify(text.value, dict(x=xs))
    except Exception as exception:
        # `expr` is local to this function, so falling through would raise
        # UnboundLocalError; report the problem and keep the previous plot.
        errbox.text = str(exception)
        return
    errbox.text = ""

    x, fy, ty = taylor(expr, xs, slider.value, (-2 * sy.pi, 2 * sy.pi), 200)

    p.title.text = "Taylor (n=%d) expansion comparison for: %s" % (
        slider.value, expr)
    legend.items[0].label = value("%s" % expr)
    legend.items[1].label = value("taylor(%s)" % expr)
    source.data = dict(x=x, fy=fy, ty=ty)


# "Order" slider: the expansion order n passed to taylor(); any change redraws via update().
slider = Slider(start=1, end=20, value=1, step=1, title="Order")
slider.on_change('value', lambda attr, old, new: update())

# Expression input, pre-filled from the module-level `expr`; edits redraw via update().
text = TextInput(value=str(expr), title="Expression:")
text.on_change('value', lambda attr, old, new: update())

# Shows the sympify error message from update() (cleared on success).
errbox = PreText()

# Initial draw; must run after errbox exists because update() writes to it.
update()

inputs = column(text, slider, errbox, width=400)

curdoc().add_root(column(inputs, p))
Exemplo n.º 31
0
class App:
    """Bokeh dashboard for inspecting a Sacred experiment.

    Shows the loss curves and config of a run, lets the user pick an epoch,
    runs inference with that epoch's model and visualises the recorded tensors.
    Widgets are created in :meth:`run`; most other methods are their callbacks.
    """

    def __init__(self):
        self._sacred_config = SacredConfigFactory.local()
        self._sacred_utils = SacredUtils(self._sacred_config)
        # NOTE(review): this stores the AppState class itself, not an instance, so
        # state lives in class attributes shared by every App — confirm intended.
        self.state = AppState
        self.tensor_2d_plot = Tensor2DPlot()
        self.tensor_2d_plot_trace = Tensor2DPlotTrace()
        self.config_analyzer = SacredRunsConfigAnalyzer(
            self._sacred_config.create_mongo_observer())

    def _run_inference(self) -> MultiObserver:
        """Run inference with the model saved at the selected epoch.

        Params come from the experiment config, with ``n_experts`` and
        ``task_size`` overridden by the text inputs; ``n_rollouts`` is the
        rollout size.

        Returns:
            The ``MultiObserver`` that recorded the run, or ``None`` if a
            ``ValueError`` occurs (e.g. a text input is not an integer); the
            error is printed.
        """
        try:
            n_inputs = int(self.widget_text_n_inputs.value)
            n_experts = int(self.widget_text_n_experts.value)
            n_rollouts = int(self.widget_text_n_rollouts.value)

            observer = MultiObserver()
            params = Params(**self.state.sacred_reader.config)
            params.n_experts = n_experts
            params.task_size = n_inputs
            self.state.inference_config = dataclasses.asdict(params)
            agent = create_agent(params)
            self.state.sacred_reader.load_model(agent, 'agent',
                                                self.state.epoch)

            rollout_size = n_rollouts  # 15 + params.rollout_size
            run_inference(params, agent, observer, rollout_size)
            return observer

        except ValueError as e:
            print(f'ValueError: {e}')

    def update_experiment(self):
        """Load the experiment whose id is in the text input and refresh every view.

        Updates the epoch select (choosing the last epoch), the config panel,
        the inference inputs, the loss plot and the inference views. Errors are
        printed, not raised.
        """
        print('Update experiment')
        try:
            # parse experiment id and load experiment from sacred
            self.widget_config_div.text = 'loading'
            self.state.experiment_id = int(
                self.widget_text_experiment_id.value)
            sr = SacredReader(self.state.experiment_id, self._sacred_config)
            self.state.sacred_reader = sr

            # update epochs select
            # NOTE(review): setting the select value and the three inference
            # inputs below each fire their on_change callbacks, which call
            # _update(); inference may therefore run several times per load.
            epochs = sr.get_epochs()
            self.state.epochs = epochs
            self.widget_button.label = f'Experiment Id: {self.state.experiment_id}'
            self.widget_epoch_select.options = list(map(str, epochs))
            self.widget_epoch_select.value = str(epochs[-1])

            # update config
            formatted_config = '<br/>'.join(
                [f'{k}: {v}' for k, v in sr.config.items()])
            self.widget_config_div.text = f'<pre>{formatted_config}</pre>'

            # update inference params
            params = Params(**sr.config)
            self.widget_text_n_rollouts.value = str(params.rollout_size)
            self.widget_text_n_experts.value = str(params.n_experts)
            self.widget_text_n_inputs.value = str(params.task_size)

            self.update_loss_plot()
            self._update()
        except Exception as e:
            print(f'Error: {e}')

    def update_loss_plot(self):
        """Redraw the training-loss figure using the smoothing window from the text input."""
        try:
            loss_average_window = int(self.widget_loss_smooth_text.value)
            # update epochs figure
            self.widget_loss_pane.children.clear()
            self.widget_loss_pane.children.append(
                self._create_loss_figure(
                    self._sacred_utils.load_metrics([self.state.experiment_id],
                                                    loss_average_window)))
        except ValueError as e:
            print(f'Error: {e}')

    def update_epoch(self, attr, old, new):
        """Bokeh on_change callback for the epoch select: store the epoch and refresh."""
        try:
            self.state.epoch = int(self.widget_epoch_select.value)
            print(f'Update epoch: {self.state.epoch}')
            self._update()
        except ValueError as e:
            print(f'Error: {e}')

    def _update(self):
        """Run inference and rebuild the tensor viewer, 2D trace plot and inference-loss plot.

        Any exception is printed rather than raised.
        """
        sr = self.state.sacred_reader
        use_training_inference_data = False
        try:
            self.widget_pane.children.clear()
            if use_training_inference_data:
                # load training data
                tensors: List[Dict[str, Tensor]] = sr.load_tensor(
                    'tensors.pt', self.state.epoch)
                self.state.tensors_data = TensorDataListDict(tensors)
            else:
                observer = self._run_inference()
                self.state.tensors_data = TensorDataMultiObserver(observer)
            viewer = TensorViewer(
                self.state.tensors_data,
                6,
                displayed_tensors=[
                    'keys_2', 'weights_2', 'attn-result_2',
                    'weights_before_softmax_2'
                ],
                rollout_on_change=self.rollout_step_on_change)
            self.widget_pane.children.append(viewer.create_layout())

            joined_data = torch.stack([
                self.state.tensors_data.tensor_by_name(rollout_step, 'keys_2')
                for rollout_step in range(self.state.tensors_data.step_count)
            ])
            self.tensor_2d_plot_trace.update(joined_data,
                                             self.state.inference_config)

            self.widget_loss_pane_inference.children.clear()
            inf_loss_data = torch.stack([
                self.state.tensors_data.tensor_by_name(rollout_step,
                                                       'error_per_inf_step')
                for rollout_step in range(self.state.tensors_data.step_count)
            ])
            self.widget_loss_pane_inference.children.append(
                self._create_loss_figure(pd.DataFrame(inf_loss_data.tolist()),
                                         width=500))

            # Plot grads
            # param_tensors = {k: v for k, v in tensors_data.tensor_map.items() if v.startswith('param')}
            # for tensor_id, tensor_name in param_tensors.items():
            #     self.widget_pane.children.append(Div(text=f'T: {tensor_name}'))
            #     fig = plot_tensor(tensors_data.tensor(2, tensor_id))
            #     self.widget_pane.children.append(fig)

        except Exception as e:
            print(f'ERR, {e}')

    def _create_loss_figure(self,
                            df: pd.DataFrame,
                            width: int = 1000,
                            height: int = 300):
        """Build a log-scale plot with one line and hover tool per ``df`` column.

        Double-tapping the figure selects the epoch closest to the tapped x.

        Args:
            df: Loss values; its index is the x axis. Column labels are
                converted to strings in place.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            The Bokeh figure.
        """
        colors = itertools.cycle(palette)
        # Bokeh column names (and "@name" tooltip fields) must be strings.
        df.columns = [str(i) for i in df.columns]
        ds = ColumnDataSource(df)

        fig = Figure(y_axis_type="log", width=width, height=height)
        fig.below[0].formatter.use_scientific = False
        # `col_name` rather than `column`, which would shadow bokeh's layout function.
        for col_name, color in zip(df.columns, colors):
            glyph = fig.line(x='index', y=col_name, source=ds, color=color)
            fig.add_tools(
                HoverTool(tooltips=[("epoch", "@index")] +
                          [(f"loss_{col_name}", f"@{col_name}")],
                          mode='vline',
                          renderers=[glyph]))

        def select_nearest_epoch(event):
            """Select the epoch closest to the double-tapped x coordinate."""
            epochs = self.state.epochs
            if not epochs:
                return
            nearest = min(epochs, key=lambda e: abs(e - event.x))
            self.widget_epoch_select.value = str(nearest)

        fig.on_event(DoubleTap, select_nearest_epoch)
        return fig

    def on_last_run(self):
        """Put the id of the most recent Sacred run into the experiment-id input."""
        run_id = self._sacred_utils.get_last_run().id
        self.widget_text_experiment_id.value = str(run_id)

    def rollout_step_on_change(self, rollout_step: int):
        """Show the ``keys_2`` tensor of the selected rollout step in the 2D plot."""
        self.state.rollout_step = rollout_step
        tensor = self.state.tensors_data.tensor_by_name(
            self.state.rollout_step, 'keys_2')
        self.tensor_2d_plot.update(tensor, self.state.inference_config)
        # self._2d_plot_source.data = self._create_2d_plot_data()

    def run(self):
        """Create all widgets, wire their callbacks and add the layout to ``curdoc()``."""
        self.widget_text_experiment_id = TextInput(value='170')
        # Registered once: it used to be registered twice with two distinct
        # lambdas, so every edit loaded the experiment twice.
        self.widget_text_experiment_id.on_change(
            'value', lambda a, o, n: self.update_experiment())
        self.widget_button_last_run = Button(label="Last run")
        self.widget_button_last_run.on_click(self.on_last_run)
        self.widget_epoch_select = Select(title='epoch', options=[])
        self.widget_button = Button(label="Read experiment")
        self.widget_button.on_click(self.update_experiment)
        self.widget_epoch_select.on_change('value', self.update_epoch)
        self.widget_pane = column()
        self.widget_loss_pane = column()
        self.widget_loss_pane_inference = column()
        self.widget_loss_smooth_text = TextInput(value='100')
        # NOTE(review): a new smoothing window reloads the whole experiment;
        # update_loss_plot() alone may have been intended.
        self.widget_loss_smooth_text.on_change(
            'value', lambda a, o, n: self.update_experiment())
        self.widget_config_div = Div(text="")
        self.widget_2d_plot = self.tensor_2d_plot.create_2d_plot()
        self.widget_2d_plot_trace = self.tensor_2d_plot_trace.create_2d_plot()

        self.widget_run_inference_button = Button(label='Run inference')
        self.widget_run_inference_button.on_click(self._update)
        self.widget_text_n_inputs = TextInput(title='n_inputs', value='5')
        self.widget_text_n_inputs.on_change('value',
                                            lambda a, o, n: self._update())
        self.widget_text_n_experts = TextInput(title='n_experts', value='2')
        self.widget_text_n_experts.on_change('value',
                                             lambda a, o, n: self._update())
        self.widget_text_n_rollouts = TextInput(title='n_rollouts', value='15')
        self.widget_text_n_rollouts.on_change('value',
                                              lambda a, o, n: self._update())

        self.widget_inference_pane = column(self.widget_text_n_inputs,
                                            self.widget_text_n_experts,
                                            self.widget_text_n_rollouts,
                                            self.widget_run_inference_button)

        curdoc().add_root(
            row(
                column(
                    row(
                        column(
                            row(Div(text="Experiment ID:"),
                                self.widget_text_experiment_id,
                                self.widget_button,
                                self.widget_button_last_run),
                            column(
                                self.widget_loss_pane,
                                row(Div(text='Smooth window:'),
                                    self.widget_loss_smooth_text))),
                        self.widget_config_div,
                    ),
                    row(
                        column(self.widget_epoch_select,
                               self.widget_inference_pane),
                        self.widget_2d_plot, self.widget_2d_plot_trace,
                        self.widget_loss_pane_inference), self.widget_pane),
                self.config_analyzer.create_layout()))
Exemplo n.º 32
0
# Overlay the critical points and critical lines of the ODE system in red.
plot.scatter('x', 'y', source=source_critical_pts, color='red', legend='critical pts')
plot.multi_line('x_ls', 'y_ls', source=source_critical_lines, color='red', legend='critical lines')

# initialize controls
# text inputs for the ODE system [u,v] = [x',y'], pre-filled with the two
# component functions of the initial sample system (init_fun_key)
u_input = TextInput(value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][0], title="u(x,y):")
v_input = TextInput(value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][1], title="v(x,y):")

# dropdown menu for selecting one of the sample function pairs
sample_fun_input = Dropdown(label="choose a sample function pair or enter one below",
                            menu=odesystem_settings.sample_system_names)

# Interactor for entering the starting point of the initial condition
# (presumably by clicking on the plot — see initial_value_change)
interactor = my_bokeh_utils.Interactor(plot)

# initialize callback behaviour
sample_fun_input.on_click(sample_fun_change)
u_input.on_change('value', ode_change)
v_input.on_change('value', ode_change)
interactor.on_click(initial_value_change)

# calculate data
init_data()

# all controls of the function panel, stacked in a 400 px wide box
function_controls = widgetbox(sample_fun_input, u_input, v_input,width=400)

# refresh quiver field and streamlines every 100 ms
curdoc().add_periodic_callback(refresh_user_view, 100)
# make layout
curdoc().add_root(column(plot, function_controls))
Exemplo n.º 33
0
# the confusion matrix
# row/column headers in small type
text_props['text_font_size'] = "5pt"
plot.text(x=0.825, y=0.21, text=["True"], **text_props)
plot.text(x=0.925, y=0.21, text=["False"], **text_props)
plot.text(x=0.725, y=0.15, text=["True"], **text_props)
plot.text(x=0.725, y=0.05, text=["False"], **text_props)

# cell values, read from the TP/FP/FN/TN columns of conf_source
text_props['text_font_size'] = "8pt"
plot.text(x=0.825, y=0.15, text="TP", source=conf_source, **text_props)
plot.text(x=0.925, y=0.15, text="FP", source=conf_source, **text_props)
plot.text(x=0.825, y=0.05, text="FN", source=conf_source, **text_props)
plot.text(x=0.925, y=0.05, text="TN", source=conf_source, **text_props)

update_data()

dataurl.on_change('value', dataurl_change)

# There must be a better way:
# client-side script that hides the auc and sample_size widgets by setting
# `hidden` on their parent DOM nodes.
dataurl.callback = CustomJS(args=dict(auc=auc,
                                      sample_size=sample_size),
                            code="""
         // $("label[for='"+auc.id+"']").parentNode.remove();
         document.getElementById(auc.id).parentNode.hidden = true;
         // $("label[for='"+sample_size.id+"']").parentNode.remove();
         document.getElementById(sample_size.id).parentNode.hidden = true;
    """)

# `text` is wired here with the other inputs; its separate, redundant
# registration of the same callback has been removed.
for w in (threshold, text, auc, sample_size):
    w.on_change('value', input_change)
Exemplo n.º 34
0
update_is_enabled = True

# initialize controls
# dropdown menu for sample functions
function_type = Dropdown(label="choose a sample function pair or enter one below",
                         menu=convolution_settings.sample_function_names)
function_type.on_click(function_pair_input_change)

# slider controlling the evaluated x value of the convolved function
x_value_input = Slider(title="x value", name='x value', value=convolution_settings.x_value_init,
                       start=convolution_settings.x_value_min, end=convolution_settings.x_value_max,
                       step=convolution_settings.x_value_step)
x_value_input.on_change('value', input_change)
# text input for the first function to be convolved
function1_input = TextInput(value=convolution_settings.function1_input_init, title="my first function:")
function1_input.on_change('value', input_change)
# text input for the second function to be convolved.
# It was initialised from function1_input_init, which looks like a copy-paste
# slip; use function2_input_init when the settings module defines it and fall
# back to the previous value otherwise.
function2_input = TextInput(value=getattr(convolution_settings, 'function2_input_init',
                                          convolution_settings.function1_input_init),
                            title="my second function:")
function2_input.on_change('value', input_change)

# initialize plot
toolset = "crosshair,pan,reset,resize,save,wheel_zoom"
# Generate a figure container
plot = Figure(plot_height=400, plot_width=400, tools=toolset,
              title="Convolution of two functions",
              x_range=[convolution_settings.x_min_view, convolution_settings.x_max_view],
              y_range=[convolution_settings.y_min_view, convolution_settings.y_max_view])

# Plot the line by the x,y values in the source property
plot.line('x', 'y', source=source_function1, line_width=3, line_alpha=0.6, color='red', legend='function 1')
plot.line('x', 'y', source=source_function2, color='green', line_width=3, line_alpha=0.6, legend='function 2')
Exemplo n.º 35
0
def modify_doc(doc):
    """Build the catalog-browser document: a rerun text box plus object/color plot tabs.

    Editing the rerun box creates a new global ``butler`` and rebuilds the
    object and color plots in place.

    Args:
        doc: Bokeh ``Document`` to populate.

    Returns:
        The same ``doc``.
    """
    repo_box = TextInput(value='/project/tmorton/DM-12873/w44',
                         title='rerun',
                         css_classes=['customTextInput'])

    # Object plots
    object_hvplots = [
        renderer.get_widget(dmap, None, doc) for dmap in object_dmaps
    ]

    object_plots = [
        layout([hvplot.state], sizing_mode='fixed')
        for hvplot in object_hvplots
    ]
    object_tabs = Tabs(tabs=[
        Panel(child=plot, title=name)
        for plot, name in zip(object_plots, config['sections']['object'])
    ])
    object_panel = Panel(child=object_tabs, title='Object Catalogs')

    # Source plots
    source_categories = config['sections']['source']
    # source_hvplots = {c : renderer.get_widget(source_dmaps[c], None, doc)
    #                     for c in source_categories}

    # # source_plots = {c : layout([source_hvplots[c].state], sizing_mode='fixed')
    # #                 for c in source_categories}
    # source_plots = {c : gridplot([[None]]) for c in source_categories}
    # source_tract_select = {c : RadioButtonGroup(labels=[str(t) for t in get_tracts(butler)], active=0)
    #                             for c in source_categories}
    # source_filt_select = {c : RadioButtonGroup(labels=wide_filters, active=2)
    #                             for c in source_categories}

    # def update_source(category):
    #     def update(attr, old, new):
    #         t_sel = source_tract_select[category]
    #         f_sel = source_filt_select[category]
    #         new_tract = int(t_sel.labels[t_sel.active])
    #         new_filt = f_sel.labels[f_sel.active]
    #         logging.info('updating {} to tract={}, filt={}...'.format(category, new_tract, new_filt))
    #         dmap = get_source_dmap(butler, category, tract=new_tract, filt=new_filt)
    #         new_hvplot = renderer.get_widget(dmap, None, doc)
    #         source_plots[category].children[0] = new_hvplot.state
    #         logging.info('update complete.')
    #     return update

    # source_tab_panels = []
    # for category in source_categories:
    #     tract_select = source_tract_select[category]
    #     filt_select = source_filt_select[category]
    #     plot = source_plots[category]

    #     tract_select.on_change('active', update_source(category))
    #     filt_select.on_change('active', update_source(category))

    #     l = layout([[tract_select, filt_select], [plot]], sizing_mode='fixed')
    #     source_tab_panels.append(Panel(child=l, title=category))

    # source_tabs = Tabs(tabs=source_tab_panels)
    # source_panel = Panel(child=source_tabs, title='Source Catalogs')

    # Color plots
    color_categories = config['sections']['color']
    color_hvplots = {
        c: renderer.get_widget(color_dmaps[c], None, doc)
        for c in color_categories
    }
    color_plots = {
        c: layout([color_hvplots[c].state], sizing_mode='fixed')
        for c in color_categories
    }

    color_tabs = Tabs(
        tabs=[Panel(child=color_plots[c], title=c) for c in color_categories])
    color_panel = Panel(child=color_tabs, title='Color')

    def update_repo(attr, old, new):
        """Bokeh on_change callback: switch the global butler to rerun ``new`` and redraw plots."""
        global butler
        butler = Butler(new)
        logging.info('Changing butler to {}'.format(new))

        # Update Object plots
        logging.info('Updating object plots...')
        object_dmaps = get_object_dmaps(butler=butler)

        new_object_hvplots = [
            renderer.get_widget(dmap, None, doc) for dmap in object_dmaps
        ]

        for plot, new_plot in zip(object_plots, new_object_hvplots):
            plot.children[0] = new_plot.state

        # Update Source plots
        # for c in source_categories:
        #     source_tract_select[c].labels = [str(t) for t in get_tracts(butler)]

        # # # THIS MUST BE FIXED.  PERHAPS SOURCE PLOTS SHOULD BE EMPTY UNTIL ACTIVATED
        # logging.info('Updating source plots...')
        # source_plots = {c : gridplot([[None]]) for c in source_categories}

        # source_dmaps = get_source_dmaps(butler=butler)
        # new_source_hvplots = {c : renderer.get_widget(source_dmaps[c], None, doc)
        #                       for c in source_categories}
        # for plot,new_plot in zip(source_plots, new_source_hvplots):
        #     plot.children[0] = new_plot.state

        # Update Color plots
        logging.info('Updating color plots...')
        color_dmaps = get_color_dmaps(butler=butler)
        new_color_hvplots = {
            c: renderer.get_widget(color_dmaps[c], None, doc)
            for c in color_categories
        }
        # Look up by category: zipping the two dicts would iterate their keys
        # (strings), and `str` has no `.children`.
        for c in color_categories:
            color_plots[c].children[0] = new_color_hvplots[c].state

    repo_box.on_change('value', update_repo)

    # uber_tabs = Tabs(tabs=[object_panel, source_panel, color_panel])
    uber_tabs = Tabs(tabs=[object_panel, color_panel])

    doc.add_root(repo_box)
    doc.add_root(uber_tabs)
    return doc
Exemplo n.º 36
0
            level = new_level
            # trigger rendering
            level_slider.value = level
        if pkg != current_rpm:
            # trigger rendering
            field.value = pkg


# Helper to get the position of nodes and put text labels at an adequate place:
# a JavaScript template that maps each node index in `xs` to one coordinate of
# provider.graph_layout. `%s` is filled in later with the coordinate index
# (presumably 0 for x and 1 for y — confirm at the use site).
code = """
    var result = new Float64Array(xs.length)
    for (var i = 0; i < xs.length; i++) {
        result[i] = provider.graph_layout[xs[i]][%s]
    }
    return result
"""
# Add tooltips on each package tag (name, version and the linking requirement)
hover = HoverTool()
hover.tooltips = [('name', '@name'), ('version', '@version'),
                  ('link through/by', '@req')]

# initial render, then wire the controls to their callbacks
plot = render(current_rpm)
field.on_change('value', update)
level_slider.on_change('value', update_level)
back_bt.on_click(back)

# put the field, slider, button and plot in a layout and add to the document
layout = column(row(field, level_slider, back_bt), plot)
layout.sizing_mode = "scale_width"
curdoc().add_root(layout)
Exemplo n.º 37
0
    networkx_osm_ops.construct_poly_query(polysample, sample_name, 
                                          min_road_type='primary')
    os.system('open ' + sample_name)

button_query_sample.on_click(on_query_sample_click)
#############

#############
# Text input
text_poly = TextInput(value=poly_string, title="Polygon Coordinates:")


def on_text_value_change_poly(obj, attr, old, new):
    """Store the polygon coordinates typed into ``text_poly`` in the global ``poly_string``."""
    # NOTE(review): newer Bokeh calls on_change callbacks as (attr, old, new);
    # this four-argument signature targets an older API — confirm the version.
    global poly_string
    # print() function: Python 2 print statements are a SyntaxError in Python 3.
    print("New polygon coords:", new)
    poly_string = new


text_poly.on_change('value', on_text_value_change_poly)
#############

#############
# define input directory
# Text input
text_input = TextInput(value=indir, title="Input Directory:")


def on_text_value_change_input(obj, attr, old, new):
    """Store the directory typed into ``text_input`` in the global ``indir``."""
    global indir
    print("Input Directory:", new)
    indir = new


text_input.on_change('value', on_text_value_change_input)

#############
Exemplo n.º 38
0
def create(palm):
    """Create the "HDF5 File" tab for reviewing PALM results of saved runs.

    The tab lists ``.hdf5``/``.h5`` files of a user-selected folder (refreshed
    every 5 s), processes the chosen file with
    ``palm.process_hdf5_file(..., debug=True)`` and shows, for the selected
    pulse, the streaked/reference waveforms and their cross-correlation, plus
    delay and length traces over all pulses. Results can be written to CSV
    manually or automatically whenever a run is loaded.

    Args:
        palm: analysis object. ``palm.energy_range`` (array-like with
            ``min``/``max``/``size``) provides the initial energy grid and is
            reassigned when the user edits the energy limits or the number of
            interpolation points.

    Returns:
        Panel: bokeh tab titled "HDF5 File".
    """
    energy_min = palm.energy_range.min()
    energy_max = palm.energy_range.max()
    energy_npoints = palm.energy_range.size

    # (filename, tags, delays, lengths) of the last processed run; a falsy
    # filename means that nothing has been processed yet.
    current_results = (0, 0, 0, 0)

    doc = curdoc()

    def make_plot(title, x_axis_label, y_axis_label):
        """Return an empty plot with the toolbar, axes and grid shared by all plots of this tab."""
        plot = Plot(
            title=Title(text=title),
            x_range=DataRange1d(),
            y_range=DataRange1d(),
            plot_height=PLOT_CANVAS_HEIGHT,
            plot_width=PLOT_CANVAS_WIDTH,
            toolbar_location="right",
        )

        # ---- tools
        plot.toolbar.logo = None
        plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

        # ---- axes
        plot.add_layout(LinearAxis(axis_label=x_axis_label), place="below")
        plot.add_layout(
            LinearAxis(axis_label=y_axis_label,
                       major_label_orientation="vertical"),
            place="left",
        )

        # ---- grid lines
        plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
        plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

        return plot

    # Streaked and reference waveforms plot
    waveform_plot = make_plot("eTOF waveforms", "Photon energy, eV",
                              "Intensity")

    waveform_source = ColumnDataSource(
        dict(x_str=[], y_str=[], x_ref=[], y_ref=[]))
    waveform_ref_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_ref", y="y_ref", line_color="blue"))
    waveform_str_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_str", y="y_str", line_color="red"))

    waveform_plot.add_layout(
        Legend(items=[("reference", [waveform_ref_line]),
                      ("streaked", [waveform_str_line])]))
    waveform_plot.legend.click_policy = "hide"

    # Cross-correlation plot; the span marks the delay of the selected pulse
    xcorr_plot = make_plot("Waveforms cross-correlation", "Energy shift, eV",
                           "Cross-correlation")

    xcorr_source = ColumnDataSource(dict(lags=[], xcorr1=[], xcorr2=[]))
    xcorr_plot.add_glyph(
        xcorr_source,
        Line(x="lags", y="xcorr1", line_color="purple", line_dash="dashed"))
    xcorr_plot.add_glyph(xcorr_source,
                         Line(x="lags", y="xcorr2", line_color="purple"))

    xcorr_center_span = Span(location=0, dimension="height")
    xcorr_plot.add_layout(xcorr_center_span)

    # Delays plot; the span marks the selected pulse
    pulse_delay_plot = make_plot("Pulse delays", "Pulse number",
                                 "Pulse delay (uncalib), eV")

    pulse_delay_source = ColumnDataSource(dict(pulse=[], delay=[]))
    pulse_delay_plot.add_glyph(
        pulse_delay_source, Line(x="pulse", y="delay", line_color="steelblue"))

    pulse_delay_plot_span = Span(location=0, dimension="height")
    pulse_delay_plot.add_layout(pulse_delay_plot_span)

    # Pulse lengths plot; the span marks the selected pulse
    pulse_length_plot = make_plot("Pulse lengths", "Pulse number",
                                  "Pulse length (uncalib), eV")

    pulse_length_source = ColumnDataSource(dict(x=[], y=[]))
    pulse_length_plot.add_glyph(pulse_length_source,
                                Line(x="x", y="y", line_color="steelblue"))

    pulse_length_plot_span = Span(location=0, dimension="height")
    pulse_length_plot.add_layout(pulse_length_plot_span)

    # Folder path text input
    def path_textinput_callback(_attr, _old_value, new_value):
        # Results are saved to the folder the runs are read from by default.
        save_textinput.value = new_value
        path_periodic_update()

    path_textinput = TextInput(title="Folder Path:",
                               value=os.path.expanduser("~"),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # Saved runs dropdown menu
    def h5_update(pulse, delays, debug_data):
        """Show waveforms and cross-correlation of one pulse and move the pulse markers."""
        prep_data, lags, corr_res_uncut, corr_results = debug_data

        waveform_source.data.update(
            x_str=palm.energy_range,
            y_str=prep_data["1"][pulse, :],
            x_ref=palm.energy_range,
            y_ref=prep_data["0"][pulse, :],
        )

        xcorr_source.data.update(lags=lags,
                                 xcorr1=corr_res_uncut[pulse, :],
                                 xcorr2=corr_results[pulse, :])

        xcorr_center_span.location = delays[pulse]
        pulse_delay_plot_span.location = pulse
        pulse_length_plot_span.location = pulse

    def h5_update_fun(pulse):
        """Placeholder until a run is loaded; rebound in saved_runs_dropdown_callback."""

    def saved_runs_dropdown_callback(_attr, _old_value, new_value):
        """Process the selected file and show its first pulse.

        Also invoked directly (with the current dropdown label) after the
        energy grid changes, to reprocess the selected run.
        """
        nonlocal h5_update_fun, current_results
        # "Saved Runs" is the initial label, i.e. no file selected yet.
        if new_value == "Saved Runs":
            return

        saved_runs_dropdown.label = new_value
        filepath = os.path.join(path_textinput.value, new_value)
        tags, delays, lengths, debug_data = palm.process_hdf5_file(
            filepath, debug=True)
        current_results = (new_value, tags, delays, lengths)

        if autosave_checkbox.active:
            save_button_callback()

        pulse_delay_source.data.update(pulse=np.arange(len(delays)),
                                       delay=delays)
        pulse_length_source.data.update(x=np.arange(len(lengths)), y=lengths)
        h5_update_fun = partial(h5_update, delays=delays, debug_data=debug_data)

        pulse_slider.end = len(delays) - 1
        pulse_slider.value = 0
        h5_update_fun(0)

    saved_runs_dropdown = Dropdown(label="Saved Runs",
                                   button_type="primary",
                                   menu=[])
    saved_runs_dropdown.on_change("value", saved_runs_dropdown_callback)

    # ---- saved run periodic update
    def path_periodic_update():
        """Refresh the dropdown with the HDF5 files of the folder, newest name first."""
        new_menu = []
        if os.path.isdir(path_textinput.value):
            new_menu = [
                (entry.name, entry.name)
                for entry in os.scandir(path_textinput.value)
                if entry.is_file() and entry.name.endswith((".hdf5", ".h5"))
            ]
        saved_runs_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Pulse number slider
    def pulse_slider_callback(_attr, _old_value, new_value):
        h5_update_fun(pulse=new_value)

    pulse_slider = Slider(
        start=0,
        end=99999,
        value=0,
        step=1,
        title="Pulse ID",
        callback_policy="throttle",
        callback_throttle=500,
    )
    pulse_slider.on_change("value", pulse_slider_callback)

    def apply_energy_range():
        """Rebuild palm.energy_range from the current limits and reprocess the selected run."""
        palm.energy_range = np.linspace(energy_min, energy_max, energy_npoints)
        saved_runs_dropdown_callback("", "", saved_runs_dropdown.label)

    # Energy maximal range value text input
    def energy_max_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_max
        # Reject a cleared value or one that would not exceed the minimum;
        # writing old_value back restores the widget.
        if new_value is not None and new_value > energy_min:
            energy_max = new_value
            apply_energy_range()
        else:
            energy_max_spinner.value = old_value

    energy_max_spinner = Spinner(title="Maximal Energy, eV:",
                                 value=energy_max,
                                 step=0.1)
    energy_max_spinner.on_change("value", energy_max_spinner_callback)

    # Energy minimal range value text input
    def energy_min_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_min
        if new_value is not None and new_value < energy_max:
            energy_min = new_value
            apply_energy_range()
        else:
            energy_min_spinner.value = old_value

    energy_min_spinner = Spinner(title="Minimal Energy, eV:",
                                 value=energy_min,
                                 step=0.1)
    energy_min_spinner.on_change("value", energy_min_spinner_callback)

    # Energy number of interpolation points text input
    def energy_npoints_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_npoints
        # np.linspace needs an integer sample count, but a Spinner may deliver
        # a float (or None when cleared); coerce, and reject counts below 2.
        if new_value is not None and int(new_value) > 1:
            energy_npoints = int(new_value)
            apply_energy_range()
        else:
            energy_npoints_spinner.value = old_value

    energy_npoints_spinner = Spinner(title="Number of interpolation points:",
                                     value=energy_npoints)
    energy_npoints_spinner.on_change("value", energy_npoints_spinner_callback)

    # Save location
    save_textinput = TextInput(title="Save Folder Path:",
                               value=os.path.expanduser("~"))

    # Autosave checkbox
    autosave_checkbox = CheckboxButtonGroup(labels=["Auto Save"],
                                            active=[],
                                            width=250)

    # Save button
    def save_button_callback():
        """Write the last processed run to <save folder>/<run name>.csv (no-op before any run)."""
        if current_results[0]:
            filename, tags, delays, lengths = current_results
            save_filename = os.path.splitext(filename)[0] + ".csv"
            df = pd.DataFrame({
                "pulse_id": tags,
                "pulse_delay": delays,
                "pulse_length": lengths
            })
            df.to_csv(os.path.join(save_textinput.value, save_filename),
                      index=False)

    save_button = Button(label="Save Results",
                         button_type="default",
                         width=250)
    save_button.on_click(save_button_callback)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, xcorr_plot),
            Spacer(width=30),
            column(
                path_textinput,
                saved_runs_dropdown,
                pulse_slider,
                Spacer(height=30),
                energy_min_spinner,
                energy_max_spinner,
                energy_npoints_spinner,
                Spacer(height=30),
                save_textinput,
                autosave_checkbox,
                save_button,
            ),
        ),
        row(pulse_delay_plot, Spacer(width=10), pulse_length_plot),
    )

    return Panel(child=tab_layout, title="HDF5 File")
Exemplo n.º 39
0
]
# Attach the function/Taylor legend to the figure; update() relabels its two items.
p.add_layout(legend)

def update():
    """Re-render the function vs. Taylor-expansion comparison from the widgets.

    Parses ``text.value`` with SymPy (the name ``x`` is mapped to ``xs``).
    On a parse error the message is shown in ``errbox`` and the plot keeps
    its previous content. Otherwise the error box is cleared and the title,
    legend labels and data source are refreshed for the order ``slider.value``.
    """
    try:
        expr = sy.sympify(text.value, dict(x=xs))
    except Exception as exception:
        # Stop here: falling through would use the unbound local ``expr``
        # and fail with UnboundLocalError.
        errbox.text = str(exception)
        return
    errbox.text = ""

    x, fy, ty = taylor(expr, xs, slider.value, (-2*sy.pi, 2*sy.pi), 200)

    p.title.text = "Taylor (n=%d) expansion comparison for: %s" % (slider.value, expr)
    legend.items[0].label = value("%s" % expr)
    legend.items[1].label = value("taylor(%s)" % expr)
    source.data = dict(x=x, fy=fy, ty=ty)

# Input widgets: expansion order, expression text and a parse-error display.
slider = Slider(start=1, end=20, value=1, step=1, title="Order")
text = TextInput(value=str(expr), title="Expression:")
errbox = PreText()

# Any edit of the order or of the expression redraws the comparison.
slider.on_change('value', lambda attr, old, new: update())
text.on_change('value', lambda attr, old, new: update())

# Draw once so the page opens with a populated plot.
update()

inputs = column(text, slider, errbox, width=400)
curdoc().add_root(column(inputs, p))
Exemplo n.º 40
0
class App:
    """Bokeh dashboard for browsing tensors and metrics of a Sacred experiment.

    ``run()`` builds the widgets and adds them to ``curdoc()``; the other
    public methods are widget callbacks that reload the selected
    experiment or epoch.
    """

    def __init__(self):
        self._sacred_config = SacredConfigFactory.local()
        self._sacred_utils = SacredUtils(self._sacred_config)
        # NOTE(review): this stores the AppState class itself, not an
        # instance, so every App shares (and mutates) the same class
        # attributes — confirm this is intended.
        self.state = AppState

    def update_experiment(self):
        """Load the experiment whose id is typed in the id input.

        Fills the epoch selector (selecting the last epoch), the config panel
        and the loss figure, then re-renders the tensor pane. A non-integer
        id, or any other ``ValueError`` raised while loading, is printed and
        otherwise ignored.
        """
        print('Update experiment')
        try:
            self.widget_config_div.text = 'loading'
            self.state.experiment_id = int(self.widget_text_experiment_id.value)
            sr = SacredReader(self.state.experiment_id, self._sacred_config)
            self.state.sacred_reader = sr
            epochs = sr.get_epochs()
            self.state.epochs = epochs
            self.widget_button.label = f'Experiment Id: {self.state.experiment_id}'
            self.widget_epoch_select.options = list(map(str, epochs))
            # update_epoch is registered on this widget's 'value' in run().
            self.widget_epoch_select.value = str(epochs[-1])

            formatted_config = '<br/>'.join(f'{k}: {v}' for k, v in sr.config.items())
            self.widget_config_div.text = f'<pre>{formatted_config}</pre>'

            # Replace the loss figure with one for the newly loaded experiment.
            self.widget_loss_pane.children.clear()
            self.widget_loss_pane.children.append(
                self._create_loss_figure(self._sacred_utils.load_metrics([self.state.experiment_id])))

            self._update()
        except ValueError as e:
            print(f'Error: {e}')

    def update_epoch(self, attr, old, new):
        """on_change callback of the epoch selector: store the epoch and re-render."""
        try:
            self.state.epoch = int(self.widget_epoch_select.value)
            print(f'Update epoch: {self.state.epoch}')
            self._update()
        except ValueError:
            # e.g. an empty selection: nothing to show.
            pass

    def _update(self):
        """Rebuild the tensor pane for the current experiment and epoch."""
        sr = self.state.sacred_reader
        tensors: List[Dict[str, Tensor]] = sr.load_tensor('tensors.pt', self.state.epoch)
        self.widget_pane.children.clear()
        try:
            tensors_data = TensorDataListDict(tensors)
            viewer = TensorViewer(tensors_data, 6, displayed_tensors=['keys_2', 'weights_2', 'attn-result_2',
                                                                      'weights_before_softmax_2'])
            self.widget_pane.children.append(viewer.create_layout())

            # Plot grads
            # param_tensors = {k: v for k, v in tensors_data.tensor_map.items() if v.startswith('param')}
            # for tensor_id, tensor_name in param_tensors.items():
            #     self.widget_pane.children.append(Div(text=f'T: {tensor_name}'))
            #     fig = plot_tensor(tensors_data.tensor(2, tensor_id))
            #     self.widget_pane.children.append(fig)

        except Exception as e:
            # UI boundary: report and keep the dashboard alive.
            print(f'ERR, {e}')

    def _create_loss_figure(self, df: pd.DataFrame):
        """Plot every column of *df* against its index (the epoch) on a log y axis.

        Each line gets its own vline hover tool; double-clicking the figure
        selects the loaded epoch closest to the clicked x coordinate.
        """
        colors = itertools.cycle(palette)
        # Column names are stringified so they can be used as field names below.
        df.columns = [str(name) for name in df.columns]
        ds = ColumnDataSource(df)
        fig = Figure(y_axis_type="log", width=1000, height=300)
        for name, color in zip(df.columns, colors):
            glyph = fig.line(x='index', y=name, source=ds, color=color)
            fig.add_tools(
                HoverTool(tooltips=[("epoch", "@index"), (f"loss_{name}", f"@{name}")],
                          mode='vline', renderers=[glyph]))

        def on_double_tap(event):
            epoch = self._closest_epoch(event.x)
            if epoch is not None:
                self.widget_epoch_select.value = str(epoch)

        fig.on_event(DoubleTap, on_double_tap)
        return fig

    def _closest_epoch(self, x):
        """Return the loaded epoch nearest to *x* (earliest on ties), or None if none are loaded."""
        epochs = getattr(self.state, 'epochs', None)
        if not epochs:
            return None
        return min(epochs, key=lambda epoch: abs(epoch - x))

    def on_last_run(self):
        """Put the id of the most recent run into the id input; its on_change reloads it."""
        run_id = self._sacred_utils.get_last_run().id
        self.widget_text_experiment_id.value = str(run_id)

    def run(self):
        """Build the widgets, wire their callbacks and add the layout to curdoc()."""
        self.widget_text_experiment_id = TextInput(value='170')
        # Register the reload exactly once: every registration is invoked on
        # each change, so a duplicate would load the experiment twice.
        self.widget_text_experiment_id.on_change('value', lambda a, o, n: self.update_experiment())
        self.widget_button_last_run = Button(label="Last run")
        self.widget_button_last_run.on_click(self.on_last_run)
        self.widget_epoch_select = Select(title='epoch', options=[])
        self.widget_button = Button(label="Read experiment")
        self.widget_button.on_click(self.update_experiment)
        self.widget_epoch_select.on_change('value', self.update_epoch)
        self.widget_pane = column()
        self.widget_loss_pane = column()
        self.widget_config_div = Div(text="")
        curdoc().add_root(
            column(
                row(
                    column(
                        row(Div(text="Experiment ID:"), self.widget_text_experiment_id, self.widget_button,
                            self.widget_button_last_run),
                        self.widget_loss_pane
                    ),
                    self.widget_config_div
                ),
                self.widget_epoch_select, self.widget_pane))
Exemplo n.º 41
0
def modify_doc(doc):
    """Populate *doc* with the climate-intervention dashboard.

    Left column: the HoloViews map ``dmap`` with a month slider, a play/pause
    button and variable/intervention dropdowns. Right column: the time series
    ``dmap_time_series`` with latitude/longitude inputs. The widget callbacks
    push events to HoloViews streams defined outside this function
    (``stream``, ``var_stream``, ``lat_stream``, ``lon_stream``,
    ``intervention_stream``).

    Args:
        doc: the Bokeh session document the layout and callbacks belong to.
    """
    # Bokeh renderers that hold current viz as its state
    hvplot = renderer.get_plot(dmap, doc)
    timeseriesPlot = renderer.get_plot(dmap_time_series, doc)

    # Id returned by doc.add_periodic_callback while the animation plays.
    callback_id = None

    def animate_update():
        """Advance the slider by one month, wrapping from the end back to the start."""
        step = slider.value + 1
        if step > end:
            step = start
        slider.value = step

    def animate():
        """Toggle play/pause and (un)register the animation callback on this document."""
        # nonlocal (not global): each session keeps its own callback id, so
        # pausing never tries to remove another session's callback.
        nonlocal callback_id
        if button.label == '► Play':
            button.label = '❚❚ Pause'
            callback_id = doc.add_periodic_callback(animate_update, 75)
        else:
            button.label = '► Play'
            doc.remove_periodic_callback(callback_id)

    def slider_update(attrname, old, new):
        """Translate the month index into a date and notify the map's time stream."""
        # Month index 0 is January 2000.
        year = 2000 + (new // 12)
        month = (new % 12) + 1
        stream.event(time_step=cftime.DatetimeNoLeap(year, month, 1))
        slider.title = "{}-{}".format(year, month)

    def variable_update(event):
        """Switch the displayed variable to the chosen dropdown item and reload all scenarios."""
        global path, fram_data, curr_var, min_range, max_range
        global control_path, control_data, global_data, global_path
        global gv_geo_plot, curr_dataset, curr_intervention, DATA_DICT
        path = './iceClinic/data/f09_g16.B.cobalt.FRAM.MAY.{}.200005-208106.nc'.format(event.item)
        control_path = './iceClinic/data/f09_g16.B.cobalt.CONTROL.MAY.{}.200005-208106.nc'.format(event.item)
        global_path = './iceClinic/data/f09_g16.B.cobalt.GLOBAL.MAY.{}.200005-208106.nc'.format(event.item)
        curr_var = event.item
        # NOTE(review): the previously opened datasets are never closed before
        # being replaced; consider closing them once nothing references them.
        fram_data = xr.open_dataset(path)
        control_data = xr.open_dataset(control_path)
        global_data = xr.open_dataset(global_path)
        DATA_DICT = {'CONTROL': control_data, 'FRAM': fram_data, 'GLOBAL': global_data}
        curr_dataset = DATA_DICT[curr_intervention]

        dataset = gv.Dataset(curr_dataset)
        stateBasemap = gv.Feature(feature.STATES)
        gv_geo_plot = (
            dataset.to(gv.Image, ['lon', 'lat'], curr_var, dynamic=True).opts(
                title='{} Intervention, {} data'.format(curr_intervention, curr_var),
                cmap=CMAP_DICT[curr_var],
                colorbar=True,
                backend='bokeh',
                projection=crs.PlateCarree(),
            )
            * gf.coastline()
            * gf.borders()
            * stateBasemap.opts(fill_alpha=0, line_width=0.5)
        )

        # Colour range shared by the FRAM and GLOBAL runs (CONTROL is not
        # included), presumably so both interventions use one scale.
        fram_min_range, fram_max_range = getMinMax(fram_data, curr_var)
        global_min_range, global_max_range = getMinMax(global_data, curr_var)
        min_range = min(fram_min_range, global_min_range)
        max_range = max(fram_max_range, global_max_range)

        gv_geo_plot = gv_geo_plot.redim(**{curr_var: hv.Dimension(curr_var, range=(min_range, max_range))})
        var_stream.event(var=event.item)

    def lat_update(attr, old, new):
        """Send the typed latitude to its stream if it is an integer in [-90, 90]."""
        try:
            lat = int(new)
        except ValueError:
            # Ignore partial or non-numeric input (e.g. "" or "-") instead of
            # raising inside the Bokeh callback.
            return
        if -90 <= lat <= 90:
            lat_stream.event(lat=lat)

    def lon_update(attr, old, new):
        """Send the typed longitude, shifted by +180, if it is an integer in [-180, 180)."""
        try:
            lon = int(new)
        except ValueError:
            return
        if -180 <= lon < 180:
            # NOTE(review): the +180 shift presumably targets a 0..360
            # longitude axis — confirm; wrapping (lon % 360) would keep 0° at 0.
            lon_stream.event(lon=lon + 180)

    def intervention_update(event):
        """Switch the displayed scenario to the chosen dropdown item for the current variable."""
        global curr_var, DATA_DICT, curr_intervention, gv_geo_plot, min_range, max_range
        curr_intervention = event.item
        curr_ds = DATA_DICT[event.item]
        dataset = gv.Dataset(curr_ds)
        # NOTE(review): stateBasemap is not assigned in this function (the one
        # in variable_update is local to it), so a module-level stateBasemap is
        # used here — confirm it exists.
        gv_geo_plot = (
            dataset.to(gv.Image, ['lon', 'lat'], curr_var, dynamic=True).opts(
                title='{} Intervention, {} data'.format(curr_intervention, curr_var),
                cmap=CMAP_DICT[curr_var],
                colorbar=True,
                backend='bokeh',
                projection=crs.PlateCarree(),
            )
            * gf.coastline()
            * gf.borders()
            * stateBasemap.opts(fill_alpha=0, line_width=0.5)
        )
        gv_geo_plot = gv_geo_plot.redim(**{curr_var: hv.Dimension(curr_var, range=(min_range, max_range))})
        intervention_stream.event(intervention=event.item)

    # Time slider.
    # It starts at 5 because the datasets start in June 2000, which is month
    # index 5 when January 2000 is index 0.
    start, end = 5, 900
    slider = Slider(start=start, end=end, value=start, step=1, title="Date", show_value=False)
    slider.on_change('value', slider_update)

    # Variable dropdown
    menu = [("Temperature", "TS"), ("Precipitation", "PRECT"), ("Fire Weather", "FWI"), ("Precipitation Index", "SPI")]
    dropdown = Dropdown(label="Select Variable", button_type="primary", menu=menu)
    dropdown.on_click(variable_update)

    # Intervention dropdown
    intervention_menu = [("Control", "CONTROL"), ("Fram", "FRAM"), ("Global", "GLOBAL")]
    intervention_dropdown = Dropdown(label="Select Intervention Type", button_type="primary", menu=intervention_menu)
    intervention_dropdown.on_click(intervention_update)

    # Latitude text input
    lat_input = TextInput(value="45", title="Latitude:")
    lat_input.on_change("value", lat_update)

    # Longitude text input
    lon_input = TextInput(value="122", title="Longitude:")
    lon_input.on_change("value", lon_update)

    # Slider play button
    button = Button(label='► Play', width=60)
    button.on_click(animate)

    # Code to generate the layout
    lat_lon_text = Div(text="<b>Note:</b> Latitude ranges from -90 to 90 and longitude from -180 to 180")
    spacer = Div(height=200)

    logo = figure(x_range=(0, 10), y_range=(0, 10), plot_width=300, plot_height=300)
    logo.image_url(url=['./iceClinic/static/logo.png'], x=0, y=0, w=10, h=10, anchor="bottom_left")
    logo.toolbar.logo, logo.toolbar_location = None, None
    # `visible` is a boolean property: use False rather than None.
    logo.xaxis.visible, logo.yaxis.visible = False, False
    logo.xgrid.grid_line_color, logo.ygrid.grid_line_color = None, None

    # Combine the holoviews plot and widgets in a layout
    logo = row(logo, align='center')
    options_row = row(slider, button, align='center')
    left_plot_row = row(hvplot.state, align='center')
    left_column = column(left_plot_row, options_row, dropdown, intervention_dropdown, sizing_mode='stretch_width', align='center')
    coords_row = row(lat_input, lon_input, align='center')
    right_plot_row = row(timeseriesPlot.state, align='center')
    right_column = column(right_plot_row, coords_row, lat_lon_text)

    graphs = row(left_column, right_column, sizing_mode="stretch_width", align='center')

    plot = column(logo, graphs, spacer, sizing_mode='stretch_width', align='center')

    # Add to the session document passed in (the one the periodic callbacks
    # use), not to whatever curdoc() happens to return.
    doc.add_root(plot)