Ejemplo n.º 1
0
 def modify_doc(doc):
     """Populate *doc* with a CheckboxGroup that drives a two-row data source.

     Clicking a checkbox stores the active indices, zero-padded and truncated
     to two entries, in the source's ``val`` column. The plot carries a
     CustomAction whose JS code is produced by ``RECORD("data", "s.data")``.
     """
     source = ColumnDataSource(dict(x=[1, 2], y=[1, 1], val=["a", "b"]))

     plot = Plot(
         plot_height=400,
         plot_width=400,
         x_range=Range1d(0, 1),
         y_range=Range1d(0, 1),
         min_border=0,
     )
     plot.add_glyph(source, Circle(x='x', y='y', size=20))
     record_js = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
     plot.add_tools(CustomAction(callback=record_js))

     def cb(active):
         # "val" must stay the same length as x/y (2), so pad with zeros
         # and cut the list back to two entries.
         source.data['val'] = (active + [0, 0])[:2]

     group = CheckboxGroup(labels=LABELS, css_classes=["foo"])
     group.on_click(cb)

     doc.add_root(column(group, plot))
Ejemplo n.º 2
0
 def modify_doc(doc):
     """Attach a checkbox group and a plot to *doc*.

     The checkbox callback writes the active indices into the ``val`` column,
     always as exactly two values (missing ones filled with 0). A CustomAction
     tool runs the JS generated by ``RECORD("data", "s.data")``.
     """
     initial = dict(x=[1, 2], y=[1, 1], val=["a", "b"])
     source = ColumnDataSource(initial)

     plot = Plot(plot_height=400, plot_width=400,
                 x_range=Range1d(0, 1), y_range=Range1d(0, 1),
                 min_border=0)
     plot.add_glyph(source, Circle(x='x', y='y', size=20))
     js_action = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
     plot.add_tools(CustomAction(callback=js_action))

     def cb(active):
         # Exactly two values are kept so "val" stays aligned with x and y.
         source.data['val'] = (active + [0, 0])[:2]

     group = CheckboxGroup(labels=LABELS, css_classes=["foo"])
     group.on_click(cb)

     doc.add_root(column(group, plot))
Ejemplo n.º 3
0
from bokeh.io import output_file, show, curdoc
from bokeh.layouts import column, row
from bokeh.models import CheckboxGroup


def checkbox_handler(new):
    """Print which checkbox options are currently active.

    Parameters
    ----------
    new : list of int
        Indices of the active checkboxes, as handed over by the
        ``CheckboxGroup.on_click`` registration below.
    """
    # A space separates the option list from "selected." in the message.
    print(f'Checkbox button option {new} selected.')


# Three options; the first one starts out ticked.
checkbox_group = CheckboxGroup(
    labels=['Option 1', 'Option 2', 'Option 3'],
    active=[0],
)
checkbox_group.on_click(checkbox_handler)

# Wrap the widget in a column, nest that in a row, and attach the row to the
# current document.
controls = column(checkbox_group)
curdoc().add_root(
    row(controls)
)
Ejemplo n.º 4
0
class FitViewer(Component):
    """Bokeh view of a histogram with an overlaid fit and optional sub-fits.

    ``create`` builds the figure and a CheckboxGroup that toggles which fit
    lines are visible; ``update`` pushes new data from a fitter object.
    """
    name = 'FitViewer'

    def __init__(self, config, tool, **kwargs):
        """
        Parameters
        ----------
        config : traitlets.loader.Config
            Configuration specified by config file or cmdline arguments.
            Used to set traitlet values.
            Set to None if no configuration to pass.
        tool : ctapipe.core.Tool
            Tool executable that is calling this component.
            Passes the correct logger to the component.
            Set to None if no Tool to pass.
        kwargs
        """
        super().__init__(config=config, parent=tool, **kwargs)
        self._active_pixel = 0

        self.figure = None
        self.cdsource = None
        self.cdsource_f = None

        self.x = None
        self.stages = None
        self.neighbours2d = None
        self.fits = None
        self.fit_labels = None
        # Labels of the fit lines currently shown; kept in sync with the
        # checkbox selection.
        self.active_fits = []

        self.cb = None

        self.layout = None

    def create(self, subfit_labels):
        """Build the figure, data sources, fit lines and checkbox layout.

        Parameters
        ----------
        subfit_labels : iterable of str
            Names of the sub-fit columns. Each gets its own red line, hidden
            until its checkbox is ticked.
        """
        # Materialise once: the labels are consumed several times below.
        subfit_labels = list(subfit_labels)

        fig = figure(title="Fit Viewer",
                     plot_width=400,
                     plot_height=400,
                     tools="",
                     toolbar_location=None,
                     outline_line_color='#595959')
        self.figure = fig

        # Histogram bars, drawn as quads from bin edges.
        self.cdsource = ColumnDataSource(
            data=dict(left=[], right=[], bottom=[], top=[]))
        fig.quad(bottom='bottom',
                 left='left',
                 right='right',
                 top='top',
                 source=self.cdsource,
                 alpha=0.5)

        # One source shared by every fit curve: a common 'x' column plus one
        # column per curve.
        cdsource_d_fit = dict(x=[], fit=[])
        for subfit in subfit_labels:
            cdsource_d_fit[subfit] = []
        self.cdsource_f = ColumnDataSource(data=cdsource_d_fit)

        # The main fit is visible by default; sub-fits start hidden.
        self.fits = dict(
            fit=fig.line('x', 'fit', source=self.cdsource_f, color='yellow'))
        for subfit in subfit_labels:
            line = fig.line('x', subfit, source=self.cdsource_f, color='red')
            line.visible = False
            self.fits[subfit] = line

        self.fit_labels = ['fit', *subfit_labels]
        self.cb = CheckboxGroup(labels=self.fit_labels, active=[0])
        self.cb.on_click(self._on_checkbox_select)
        # Reflect the initial checkbox state (only 'fit' ticked).
        self.active_fits = [self.fit_labels[i] for i in self.cb.active]

        self.layout = layout([[fig, self.cb]])

    def update(self, fitter):
        """Replace the plotted histogram and fit curves with *fitter*'s data.

        Parameters
        ----------
        fitter
            Object exposing ``hist``, ``edges`` (with ``.size``; presumably a
            numpy array of bin edges), ``fit_x``, ``fit`` and a ``subfits``
            mapping of label -> values.
        """
        hist = fitter.hist
        edges = fitter.edges
        zeros = np.zeros(edges.size - 1)
        left = edges[:-1]
        right = edges[1:]
        cdsource_d = dict(left=left, right=right, bottom=zeros, top=hist)
        self.cdsource.data = cdsource_d

        cdsource_d_fit = dict(x=fitter.fit_x, fit=fitter.fit)
        for subfit, values in fitter.subfits.items():
            cdsource_d_fit[subfit] = values
        self.cdsource_f.data = cdsource_d_fit

    def _on_checkbox_select(self, active):
        """Show exactly the fit lines whose checkboxes are ticked.

        Parameters
        ----------
        active : list of int
            Indices into ``self.fit_labels`` of the ticked checkboxes, as
            passed by ``CheckboxGroup.on_click``.
        """
        self.active_fits = [self.fit_labels[i] for i in active]
        visible = set(self.active_fits)
        for fit, line in self.fits.items():
            line.visible = fit in visible
Ejemplo n.º 5
0
    def create_layout(self):
        """Build the map figure, its renderer layers and the control widgets.

        Sets ``self.x_range``/``self.y_range``, ``self.fig``, the tile, image
        and label sources/renderers, and ``self.field_select`` /
        ``self.aggregate_select``. Assembles everything into
        ``self.controls``, ``self.map_controls``, ``self.map_area`` and
        ``self.layout``. Finally stores the figure on ``self.model`` and calls
        ``self.model.update_hover()``.
        """
        # create figure
        # map_extent is indexed as [0]/[2] for x and [1]/[3] for y, so it
        # appears to be (xmin, ymin, xmax, ymax) -- confirm against the model.
        self.x_range = Range1d(start=self.model.map_extent[0],
                               end=self.model.map_extent[2],
                               bounds=None)
        self.y_range = Range1d(start=self.model.map_extent[1],
                               end=self.model.map_extent[3],
                               bounds=None)

        # lod_threshold=None turns off Bokeh's level-of-detail downsampling
        # during interaction.
        self.fig = Figure(tools='wheel_zoom,pan',
                          x_range=self.x_range,
                          lod_threshold=None,
                          plot_width=self.model.plot_width,
                          plot_height=self.model.plot_height,
                          background_fill_color='black',
                          y_range=self.y_range)

        self.fig.min_border_top = 0
        self.fig.min_border_bottom = 10
        self.fig.min_border_left = 0
        self.fig.min_border_right = 0
        self.fig.axis.visible = False

        self.fig.xgrid.grid_line_color = None
        self.fig.ygrid.grid_line_color = None

        # Layers are appended in this order: basemap, datashader image,
        # labels (presumably also the stacking order -- labels on top).
        # add tiled basemap
        self.tile_source = WMTSTileSource(url=self.model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        self.fig.renderers.append(self.tile_renderer)

        # add datashader layer
        self.image_source = ImageSource(
            url=self.model.service_url,
            extra_url_vars=self.model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(
            image_source=self.image_source)
        self.fig.renderers.append(self.image_renderer)

        # add label layer
        self.label_source = WMTSTileSource(url=self.model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        self.fig.renderers.append(self.label_renderer)

        # Add placeholder for legends (temporarily disabled)
        # self.model.legend_side_vbox = Column()
        # self.model.legend_bottom_vbox = Column()

        # add ui components
        # NOTE(review): these Selects pass `name=`, which is Bokeh's model
        # identifier, not the visible caption (`title=`); they likely render
        # without labels -- confirm intent.
        controls = []
        axes_select = Select(name='Axes', options=list(self.model.axes.keys()))
        axes_select.on_change('value', self.on_axes_change)
        controls.append(axes_select)

        self.field_select = Select(name='Field',
                                   options=list(self.model.fields.keys()))
        self.field_select.on_change('value', self.on_field_change)
        controls.append(self.field_select)

        self.aggregate_select = Select(
            name='Aggregate',
            options=list(self.model.aggregate_functions.keys()))
        self.aggregate_select.on_change('value', self.on_aggregate_change)
        controls.append(self.aggregate_select)

        transfer_select = Select(name='Transfer Function',
                                 options=list(
                                     self.model.transfer_functions.keys()))
        transfer_select.on_change('value', self.on_transfer_function_change)
        controls.append(transfer_select)

        color_ramp_select = Select(name='Color Ramp',
                                   options=list(self.model.color_ramps.keys()))
        color_ramp_select.on_change('value', self.on_color_ramp_change)
        controls.append(color_ramp_select)

        spread_size_slider = Slider(title="Spread Size (px)",
                                    value=0,
                                    start=0,
                                    end=10,
                                    step=1)
        spread_size_slider.on_change('value', self.on_spread_size_change)
        controls.append(spread_size_slider)

        # hover (temporarily disabled)
        #hover_size_slider = Slider(title="Hover Size (px)", value=8, start=4,
        #                           end=30, step=1)
        #hover_size_slider.on_change('value', self.on_hover_size_change)
        #controls.append(hover_size_slider)

        # legends (temporarily disabled)
        # controls.append(self.model.legend_side_vbox)

        # add map components
        basemap_select = Select(name='Basemap',
                                value='Imagery',
                                options=list(self.model.basemaps.keys()))
        basemap_select.on_change('value', self.on_basemap_change)

        # Opacity sliders use a 0-100 scale.
        image_opacity_slider = Slider(title="Opacity",
                                      value=100,
                                      start=0,
                                      end=100,
                                      step=1)
        image_opacity_slider.on_change('value',
                                       self.on_image_opacity_slider_change)

        basemap_opacity_slider = Slider(title="Basemap Opacity",
                                        value=100,
                                        start=0,
                                        end=100,
                                        step=1)
        basemap_opacity_slider.on_change('value',
                                         self.on_basemap_opacity_slider_change)

        # Single checkbox, ticked initially; on_click forwards to
        # on_labels_change.
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        map_controls = [
            basemap_select, basemap_opacity_slider, image_opacity_slider,
            show_labels_chk
        ]

        # Final layout: controls column on the left; map-controls row above
        # the figure on the right.
        self.controls = Column(height=600, children=controls)
        self.map_controls = Row(width=self.fig.plot_width,
                                children=map_controls)

        # legends (temporarily disabled)
        self.map_area = Column(width=900,
                               height=600,
                               children=[self.map_controls, self.fig])
        self.layout = Row(width=1300,
                          height=600,
                          children=[self.controls, self.map_area])
        self.model.fig = self.fig
        self.model.update_hover()
Ejemplo n.º 6
0
    for w in currBloodInputFields:
        w.value = '0'


def zoom(attrname):
    """Switch the blood-prediction plot's x-axis between zoomed and full view.

    Registered with ``zoomCheckbox.on_click``. The positional argument is
    whatever ``on_click`` passes and is not used here; the checkbox state is
    read from ``zoomCheckbox.active`` instead.
    """
    zoomed_in = bool(zoomCheckbox.active)
    print('active' if zoomed_in else 'inactive')
    x_range = predictPlot_blood.x_range
    # Both views share the same left edge; only the right edge differs.
    x_range.start = -.5
    x_range.end = 1.5 if zoomed_in else 26.5


# Wire the widgets to their handlers: the checkbox toggles the x-axis zoom,
# the button runs clearBloodFields.
zoomCheckbox.on_click(zoom)
clearButton.on_click(clearBloodFields)

# currWeightSlider.on_change('value', update_currWeightSlider)
# currWaterSlider.on_change('value', update_currWaterSlider)
# currAlcohol_freqeuncySelector.on_change('value', update_currAlcohol_freqeuncySelector)
# currDailyScreenTimeSlider.on_change('value', update_currDailyScreenTimeSlider)
# currActivityLevelSlider.on_change('value', update_currActivityLevelSlider)
# currHours_of_sleepSlider.on_change('value', update_currHours_of_sleepSlider)
# currCarboSlider.on_change('value', update_currCarboSlider)
# currFatSlider.on_change('value', update_currFatSlider)
# currProtnSlider.on_change('value', update_currProtnSlider)

selectorsAndSliders = [
    modelSelector, ageSlider, sexSelector, locationSelector,
    medicalConditionSelector, heightSlider, currWeightSlider, currWaterSlider,
def create():
    doc = curdoc()
    det_data = {}
    cami_meta = {}

    def proposal_textinput_callback(_attr, _old, new):
        nonlocal cami_meta
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith(".hdf"):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list

        cami_meta = {}

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal cami_meta
        with io.StringIO(base64.b64decode(new).decode()) as file:
            cami_meta = pyzebra.parse_h5meta(file)
            file_list = cami_meta["filelist"]
            file_select.options = [(entry, os.path.basename(entry))
                                   for entry in file_list]

    upload_div = Div(text="or upload .cami file:", margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".cami", width=200)
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if main_auto_checkbox.active:
            im_min = np.min(current_image)
            im_max = np.max(current_image)

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

        if "mf" in det_data:
            metadata_table_source.data.update(mf=[det_data["mf"][index]])
        else:
            metadata_table_source.data.update(mf=[None])

        if "temp" in det_data:
            metadata_table_source.data.update(temp=[det_data["temp"][index]])
        else:
            metadata_table_source.data.update(temp=[None])

        gamma, nu = calculate_pol(det_data, index)
        omega = np.ones((IMAGE_H, IMAGE_W)) * det_data["omega"][index]
        image_source.data.update(gamma=[gamma], nu=[nu], omega=[omega])

    def update_overview_plot():
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x],
                                                 dw=[n_x],
                                                 dh=[n_im])
        overview_plot_y_image_source.data.update(image=[overview_y],
                                                 dw=[n_y],
                                                 dh=[n_im])

        if proj_auto_checkbox.active:
            im_min = min(np.min(overview_x), np.min(overview_y))
            im_max = max(np.max(overview_x), np.max(overview_y))

            proj_display_min_spinner.value = im_min
            proj_display_max_spinner.value = im_max

            overview_plot_x_image_glyph.color_mapper.low = im_min
            overview_plot_y_image_glyph.color_mapper.low = im_min
            overview_plot_x_image_glyph.color_mapper.high = im_max
            overview_plot_y_image_glyph.color_mapper.high = im_max

        frame_range.start = 0
        frame_range.end = n_im
        frame_range.reset_start = 0
        frame_range.reset_end = n_im
        frame_range.bounds = (0, n_im)

        scan_motor = det_data["scan_motor"]
        overview_plot_y.axis[1].axis_label = f"Scanning motor, {scan_motor}"

        var = det_data[scan_motor]
        var_start = var[0]
        var_end = var[-1] + (var[-1] - var[0]) / (n_im - 1)

        scanning_motor_range.start = var_start
        scanning_motor_range.end = var_end
        scanning_motor_range.reset_start = var_start
        scanning_motor_range.reset_end = var_end
        # handle both, ascending and descending sequences
        scanning_motor_range.bounds = (min(var_start,
                                           var_end), max(var_start, var_end))

    def file_select_callback(_attr, old, new):
        nonlocal det_data
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            file_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        det_data = pyzebra.read_detector_data(new[0])

        if cami_meta and "crystal" in cami_meta:
            det_data["ub"] = cami_meta["crystal"]["UB"]

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        index_slider.end = det_data["data"].shape[0] - 1

        zebra_mode = det_data["zebra_mode"]
        if zebra_mode == "nb":
            metadata_table_source.data.update(geom=["normal beam"])
        else:  # zebra_mode == "bi"
            metadata_table_source.data.update(geom=["bisecting"])

        update_image(0)
        update_overview_plot()

    file_select = MultiSelect(title="Available .hdf files:",
                              width=210,
                              height=250)
    file_select.on_change("value", file_select_callback)

    def index_callback(_attr, _old, new):
        update_image(new)

    index_slider = Slider(value=0, start=0, end=1, show_value=False, width=400)

    index_spinner = Spinner(title="Image index:", value=0, low=0, width=100)
    index_spinner.on_change("value", index_callback)

    index_slider.js_link("value_throttled", index_spinner, "value")
    index_spinner.js_link("value", index_slider, "value")

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_PLOT_H,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            gamma=[np.zeros((1, 1))],
            nu=[np.zeros((1, 1))],
            omega=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    gamma_glyph = Image(image="gamma",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)
    nu_glyph = Image(image="nu",
                     x="x",
                     y="y",
                     dw="dw",
                     dh="dh",
                     global_alpha=0)
    omega_glyph = Image(image="omega",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)
    plot.add_glyph(image_source, gamma_glyph)
    plot.add_glyph(image_source, nu_glyph)
    plot.add_glyph(image_source, omega_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_PLOT_H,
        plot_width=150,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("gamma", "@gamma"),
        ("nu", "@nu"),
        ("omega", "@omega"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        if new["x"]:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            left = int(np.floor(new["x"][0]))
            right = int(np.ceil(new["x"][0] + new["width"][0]))
            bottom = int(np.floor(new["y"][0]))
            top = int(np.ceil(new["y"][0] + new["height"][0]))
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame ranges
    frame_range = Range1d(0, 1, bounds=(0, 1))
    scanning_motor_range = Range1d(0, 1, bounds=(0, 1))

    det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W))
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_W - 3,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_W],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H))
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_H + 22,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(
        LinearAxis(
            y_range_name="scanning_motor",
            axis_label="Scanning motor",
            major_label_orientation="vertical",
        ),
        place="right",
    )

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_H],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new])
        overview_plot_x_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])
        overview_plot_y_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])

    colormap = Select(title="Colormap:",
                      options=list(cmap_dict.keys()),
                      width=210)
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    STEP = 1

    def main_auto_checkbox_callback(state):
        """Enable/disable the manual display-range spinners and redraw.

        ``state`` holds the active checkbox indices. A non-empty list means
        auto-range is on, so the manual min/max spinners are disabled.
        """
        auto_range = bool(state)
        display_min_spinner.disabled = auto_range
        display_max_spinner.disabled = auto_range

        # update_image() indexes det_data["data"]. Before any .hdf file is
        # selected, det_data is empty and that lookup would raise KeyError.
        if "data" in det_data:
            update_image()

    main_auto_checkbox = CheckboxGroup(labels=["Main Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    main_auto_checkbox.on_click(main_auto_checkbox_callback)

    def display_max_spinner_callback(_attr, _old_value, new_value):
        display_min_spinner.high = new_value - STEP
        image_glyph.color_mapper.high = new_value

    display_max_spinner = Spinner(
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    def display_min_spinner_callback(_attr, _old_value, new_value):
        display_max_spinner.low = new_value + STEP
        image_glyph.color_mapper.low = new_value

    display_min_spinner = Spinner(
        low=0,
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    PROJ_STEP = 0.1

    def proj_auto_checkbox_callback(state):
        """Enable/disable the projection-range spinners and redraw overviews.

        ``state`` holds the active checkbox indices. A non-empty list means
        auto-range is on, so the manual min/max spinners are disabled.
        """
        auto_range = bool(state)
        proj_display_min_spinner.disabled = auto_range
        proj_display_max_spinner.disabled = auto_range

        # update_overview_plot() indexes det_data["data"]. Before any .hdf
        # file is selected, det_data is empty and that lookup would raise
        # KeyError.
        if "data" in det_data:
            update_overview_plot()

    proj_auto_checkbox = CheckboxGroup(labels=["Projections Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    proj_auto_checkbox.on_click(proj_auto_checkbox_callback)

    def proj_display_max_spinner_callback(_attr, _old_value, new_value):
        proj_display_min_spinner.high = new_value - PROJ_STEP
        overview_plot_x_image_glyph.color_mapper.high = new_value
        overview_plot_y_image_glyph.color_mapper.high = new_value

    proj_display_max_spinner = Spinner(
        low=0 + PROJ_STEP,
        value=1,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_max_spinner.on_change("value",
                                       proj_display_max_spinner_callback)

    def proj_display_min_spinner_callback(_attr, _old_value, new_value):
        proj_display_max_spinner.low = new_value + PROJ_STEP
        overview_plot_x_image_glyph.color_mapper.low = new_value
        overview_plot_y_image_glyph.color_mapper.low = new_value

    proj_display_min_spinner = Spinner(
        low=0,
        high=1 - PROJ_STEP,
        value=0,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_min_spinner.on_change("value",
                                       proj_display_min_spinner_callback)

    def hkl_button_callback():
        """Compute h, k, l for the index selected in ``index_spinner``.

        The result is written into ``image_source`` as one-element columns.
        """
        selected_index = index_spinner.value
        h_val, k_val, l_val = calculate_hkl(det_data, selected_index)
        image_source.data.update(h=[h_val], k=[k_val], l=[l_val])

    hkl_button = Button(label="Calculate hkl (slow)", width=210)
    hkl_button.on_click(hkl_button_callback)

    def events_list_callback(_attr, _old, new):
        """Forward edits of this panel's events text area to ``doc.events_list_spind``.

        ``doc.events_list_spind`` is not created in this view; presumably it is
        attached to ``doc`` by another panel, mirroring how this panel
        publishes ``doc.events_list_hdf_viewer`` -- confirm.
        """
        doc.events_list_spind.value = new

    events_list = TextAreaInput(rows=7, width=830)
    events_list.on_change("value", events_list_callback)
    doc.events_list_hdf_viewer = events_list

    def add_event_button_callback():
        """Fit the peak inside the current view and append one spind event line.

        The detector data inside the current frame / y / x view ranges is
        reduced to three 1D profiles: per frame, per x column and per y row.
        Each profile is fitted with ``gauss``.

        - The frame fit locates the peak along the scan. The scan-motor angle is
          linearly interpolated between neighbouring frames. This fit also gives
          the integrated intensity.
        - The x/y fits give the detector pixel position.

        These values are converted with pyzebra into a diffraction vector and a
        d-spacing. The line
        ``"x_pos y_pos intensity snr_cnts dv1 dv2 dv3 d_spacing"`` is then
        appended to ``events_list`` on a new line.
        """
        p0 = [1.0, 0.0, 1.0]
        maxfev = 100000

        def fit_profile(profile):
            # Fit ``gauss`` to a 1D profile indexed 0..len-1. Below, coeff[1] is
            # used as the peak centre and coeff[2] as its width; presumably the
            # parameter order of ``gauss`` is (amplitude, centre, sigma) -- confirm.
            coeff, _ = curve_fit(gauss,
                                 range(len(profile)),
                                 profile,
                                 p0=p0,
                                 maxfev=maxfev)
            return coeff

        wave = det_data["wave"]
        ddist = det_data["ddist"]

        # Starting angles; the scanned one is replaced below by its value
        # interpolated at the fitted peak position.
        angles = {
            motor: det_data[motor][0]
            for motor in ("gamma", "omega", "nu", "chi", "phi")
        }

        scan_motor = det_data["scan_motor"]
        var_angle = det_data[scan_motor]

        x0 = int(np.floor(det_x_range.start))
        xN = int(np.ceil(det_x_range.end))
        y0 = int(np.floor(det_y_range.start))
        yN = int(np.ceil(det_y_range.end))
        fr0 = int(np.floor(frame_range.start))
        frN = int(np.ceil(frame_range.end))
        data_roi = det_data["data"][fr0:frN, y0:yN, x0:xN]

        # ---- frame (scan) profile
        cnts = np.sum(data_roi, axis=(1, 2))
        frame_coeff = fit_profile(cnts)

        sd = cnts.std()
        # A plain conditional rather than np.where: np.where evaluates mean/sd
        # eagerly and emits a divide-by-zero RuntimeWarning when sd == 0.
        snr_cnts = 0.0 if sd == 0 else cnts.mean() / sd

        frC = fr0 + frame_coeff[1]
        var_F = var_angle[math.floor(frC)]
        var_C = var_angle[math.ceil(frC)]
        frStep = frC - math.floor(frC)
        var_step = var_C - var_F
        var_p = var_F + var_step * frStep
        if scan_motor in angles:
            angles[scan_motor] = var_p

        # Area of the fitted Gaussian, amplitude * |sigma| * sqrt(2*pi), with
        # sigma converted from frames to scan-motor units via var_step.
        # The amplitude is frame_coeff[0]. frame_coeff[1] is the peak centre
        # (used as such for frC above) and must not scale the intensity.
        intensity = frame_coeff[0] * abs(
            frame_coeff[2] * var_step) * math.sqrt(2) * math.sqrt(np.pi)

        # ---- detector x / y profiles: fitted centres give the pixel position
        x_pos = x0 + fit_profile(np.sum(data_roi, axis=(0, 1)))[1]
        y_pos = y0 + fit_profile(np.sum(data_roi, axis=(0, 2)))[1]

        # det2pol yields ``ga`` and a new ``nu``, and both feed z1frmd.
        ga, nu = pyzebra.det2pol(ddist, angles["gamma"], angles["nu"], x_pos,
                                 y_pos)
        diff_vector = pyzebra.z1frmd(wave, ga, angles["omega"], angles["chi"],
                                     angles["phi"], nu)
        d_spacing = float(pyzebra.dandth(wave, diff_vector)[0])
        dv1, dv2, dv3 = diff_vector.flatten() * 1e10

        # Build the new text first and assign it once, so listeners of
        # events_list see a single change instead of two.
        text = events_list.value
        if text and not text.endswith("\n"):
            text += "\n"
        events_list.value = (
            text +
            f"{x_pos} {y_pos} {intensity} {snr_cnts} {dv1} {dv2} {dv3} {d_spacing}"
        )

    add_event_button = Button(label="Add spind event")
    add_event_button.on_click(add_event_button_callback)

    metadata_table_source = ColumnDataSource(
        dict(geom=[""], temp=[None], mf=[None]))
    num_formatter = NumberFormatter(format="0.00", nan_format="")
    metadata_table = DataTable(
        source=metadata_table_source,
        columns=[
            TableColumn(field="geom", title="Geometry", width=100),
            TableColumn(field="temp",
                        title="Temperature",
                        formatter=num_formatter,
                        width=100),
            TableColumn(field="mf",
                        title="Magnetic Field",
                        formatter=num_formatter,
                        width=100),
        ],
        width=300,
        height=50,
        autosize_mode="none",
        index_position=None,
    )

    # Final layout
    import_layout = column(proposal_textinput, upload_div, upload_button,
                           file_select)
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False))
    colormap_layout = column(
        colormap,
        main_auto_checkbox,
        row(display_min_spinner, display_max_spinner),
        proj_auto_checkbox,
        row(proj_display_min_spinner, proj_display_max_spinner),
    )

    layout_controls = column(
        row(metadata_table, index_spinner,
            column(Spacer(height=25), index_slider)),
        row(add_event_button, hkl_button),
        row(events_list),
    )

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
            toolbar_location="left",
        ), )

    tab_layout = row(
        column(import_layout, colormap_layout),
        column(layout_overview, layout_controls),
        column(roi_avg_plot, layout_image),
    )

    return Panel(child=tab_layout, title="hdf viewer")
Ejemplo n.º 8
0
    def __init__(self,
                 nplots,
                 plot_height=350,
                 plot_width=700,
                 lower=0,
                 upper=1000,
                 nbins=100):
        """Initialize histogram plots.

        Args:
            nplots (int): Number of histogram plots that will share common controls.
            plot_height (int, optional): Height of plot area in screen pixels. Defaults to 350.
            plot_width (int, optional): Width of plot area in screen pixels. Defaults to 700.
            lower (int, optional): Initial lower range of the bins. Defaults to 0.
            upper (int, optional): Initial upper range of the bins. Defaults to 1000.
            nbins (int, optional): Initial number of the bins. Defaults to 100.

        Creates ``self.plots`` (one Bokeh ``Plot`` with a quad glyph per
        histogram) and the control widgets ``auto_toggle``, ``lower_spinner``,
        ``upper_spinner``, ``nbins_spinner`` and ``log10counts_toggle``.
        Changing the range, the number of bins or the log toggle clears the
        accumulated counts through ``_empty_counts``.
        """
        # Pan, box-zoom and wheel-zoom are the same tool instances on every
        # plot; save and reset get fresh instances per plot.
        shared_tools = (PanTool(), BoxZoomTool(), WheelZoomTool())

        self.plots = []
        self._plot_sources = []
        for _ in range(nplots):
            plot = Plot(
                x_range=DataRange1d(),
                y_range=DataRange1d(),
                plot_height=plot_height,
                plot_width=plot_width,
                toolbar_location="left",
            )
            plot.toolbar.logo = None
            plot.add_tools(*shared_tools, SaveTool(), ResetTool())

            # axes and grid lines
            plot.add_layout(LinearAxis(), place="below")
            plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")
            for dim in (0, 1):
                plot.add_layout(Grid(dimension=dim, ticker=BasicTicker()))

            # one quad per histogram bin, fed from an initially empty source
            quad_source = ColumnDataSource(dict(left=[], right=[], top=[]))
            plot.add_glyph(
                quad_source,
                Quad(left="left",
                     right="right",
                     top="top",
                     bottom=0,
                     fill_color="steelblue"),
            )

            self.plots.append(plot)
            self._plot_sources.append(quad_source)

        self._counts = []
        self._empty_counts()

        # ---- automatic / manual histogram range
        def auto_toggle_callback(state):
            # A non-empty active list means "automatic": lock the manual spinners.
            lower_spinner.disabled = upper_spinner.disabled = bool(state)

        self.auto_toggle = auto_toggle = CheckboxGroup(
            labels=["Auto Hist Range"], active=[0], default_size=145)
        auto_toggle.on_click(auto_toggle_callback)

        # ---- lower bound of the histogram range
        def lower_spinner_callback(_attr, _old_value, new_value):
            self.upper_spinner.low = new_value + STEP
            self._empty_counts()

        self.lower_spinner = lower_spinner = Spinner(
            title="Lower Range:",
            high=upper - STEP,
            value=lower,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        lower_spinner.on_change("value", lower_spinner_callback)

        # ---- upper bound of the histogram range
        def upper_spinner_callback(_attr, _old_value, new_value):
            self.lower_spinner.high = new_value - STEP
            self._empty_counts()

        self.upper_spinner = upper_spinner = Spinner(
            title="Upper Range:",
            low=lower + STEP,
            value=upper,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        upper_spinner.on_change("value", upper_spinner_callback)

        # ---- number of bins
        def nbins_spinner_callback(_attr, _old_value, _new_value):
            self._empty_counts()

        self.nbins_spinner = nbins_spinner = Spinner(title="Number of Bins:",
                                                     low=1,
                                                     value=nbins,
                                                     default_size=145)
        nbins_spinner.on_change("value", nbins_spinner_callback)

        # ---- log10 of counts
        def log10counts_toggle_callback(state):
            self._empty_counts()
            y_label = "log⏨(Counts)" if state else "Counts"
            for histogram_plot in self.plots:
                histogram_plot.yaxis[0].axis_label = y_label

        self.log10counts_toggle = log10counts_toggle = CheckboxGroup(
            labels=["log⏨(Counts)"], default_size=145)
        log10counts_toggle.on_click(log10counts_toggle_callback)
Ejemplo n.º 9
0
    def create_layout(self):
        """Assemble the map figure, its rendering layers and the UI controls.

        Sets ``self.x_range``/``self.y_range`` from ``self.model.map_extent``,
        creates ``self.fig`` with basemap-tile, datashader-image and label-tile
        renderers, and builds ``self.controls``, ``self.map_controls``,
        ``self.map_area`` and ``self.layout``. Finally it hands the figure to
        the model and calls ``self.model.update_hover()``.
        """
        model = self.model
        extent = model.map_extent

        def make_select(name, options, handler, **extra):
            # Select whose "value" changes are routed to ``handler``.
            widget = Select.create(name=name, options=options, **extra)
            widget.on_change("value", handler)
            return widget

        def make_slider(title, value, start, end, handler):
            # Integer slider (step 1) whose "value" changes go to ``handler``.
            widget = Slider(title=title, value=value, start=start, end=end, step=1)
            widget.on_change("value", handler)
            return widget

        # ---- figure spanning the model's map extent
        self.x_range = Range1d(start=extent[0], end=extent[2], bounds=None)
        self.y_range = Range1d(start=extent[1], end=extent[3], bounds=None)
        self.fig = Figure(
            tools="wheel_zoom,pan",
            x_range=self.x_range,
            y_range=self.y_range,
            lod_threshold=None,
            plot_width=model.plot_width,
            plot_height=model.plot_height,
            background_fill_color="black",
        )
        fig = self.fig

        # no borders except 10px at the bottom; no axes, no grid lines
        fig.min_border_top = 0
        fig.min_border_bottom = 10
        fig.min_border_left = 0
        fig.min_border_right = 0
        fig.axis.visible = False
        fig.xgrid.grid_line_color = None
        fig.ygrid.grid_line_color = None

        # ---- renderers, appended in this order: basemap, datashader, labels
        self.tile_source = WMTSTileSource(url=model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        fig.renderers.append(self.tile_renderer)

        self.image_source = ImageSource(url=model.service_url,
                                        extra_url_vars=model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        fig.renderers.append(self.image_renderer)

        self.label_source = WMTSTileSource(url=model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        fig.renderers.append(self.label_renderer)

        # Placeholders for legends are temporarily disabled:
        # model.legend_side_vbox = VBox(); model.legend_bottom_vbox = VBox()

        # ---- data / shading controls (left column)
        axes_select = make_select("Axes", model.axes, self.on_axes_change)
        self.field_select = make_select("Field", model.fields,
                                        self.on_field_change)
        self.aggregate_select = make_select("Aggregate",
                                            model.aggregate_functions,
                                            self.on_aggregate_change)
        transfer_select = make_select("Transfer Function",
                                      model.transfer_functions,
                                      self.on_transfer_function_change)
        color_ramp_select = make_select("Color Ramp", model.color_ramps,
                                        self.on_color_ramp_change)
        spread_size_slider = make_slider("Spread Size (px)", 0, 0, 10,
                                         self.on_spread_size_change)
        hover_size_slider = make_slider("Hover Size (px)", 8, 4, 30,
                                        self.on_hover_size_change)
        # (legend column temporarily disabled: model.legend_side_vbox)
        controls = [
            axes_select,
            self.field_select,
            self.aggregate_select,
            transfer_select,
            color_ramp_select,
            spread_size_slider,
            hover_size_slider,
        ]

        # ---- map controls (row above the figure)
        basemap_select = make_select("Basemap", model.basemaps,
                                     self.on_basemap_change, value="Imagery")
        image_opacity_slider = make_slider("Opacity", 100, 0, 100,
                                           self.on_image_opacity_slider_change)
        basemap_opacity_slider = make_slider(
            "Basemap Opacity", 100, 0, 100,
            self.on_basemap_opacity_slider_change)
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        # ---- overall layout
        self.controls = VBox(height=600, children=controls)
        self.map_controls = HBox(width=fig.plot_width,
                                 children=[basemap_select,
                                           basemap_opacity_slider,
                                           image_opacity_slider,
                                           show_labels_chk])
        self.map_area = VBox(width=900, height=600,
                             children=[self.map_controls, fig])
        self.layout = HBox(width=1300, height=600,
                           children=[self.controls, self.map_area])

        model.fig = fig
        model.update_hover()
    def create_layout(self):
        """Assemble the map figure, hover layer, legends and UI controls.

        Sets ``self.x_range``/``self.y_range`` from ``self.model.map_extent``.
        It then creates ``self.fig`` with these renderers:

        - basemap tiles;
        - the datashader image;
        - label tiles;
        - the hover layer (its tool is also added to the figure, and the layer
          is stored as ``self.model.hover_layer``).

        It creates the side/bottom legend VBoxes on the model and builds
        ``self.controls``, ``self.map_controls``, ``self.map_area`` and
        ``self.layout``.
        """
        model = self.model
        extent = model.map_extent

        def make_select(name, options, handler, **extra):
            # Select whose 'value' changes are routed to ``handler``.
            widget = Select.create(name=name, options=options, **extra)
            widget.on_change('value', handler)
            return widget

        def make_slider(title, value, start, end, handler):
            # Integer slider (step 1) whose 'value' changes go to ``handler``.
            widget = Slider(title=title, value=value, start=start, end=end, step=1)
            widget.on_change('value', handler)
            return widget

        # ---- figure spanning the model's map extent
        self.x_range = Range1d(start=extent[0], end=extent[2], bounds=None)
        self.y_range = Range1d(start=extent[1], end=extent[3], bounds=None)
        self.fig = Figure(tools='wheel_zoom,pan',
                          x_range=self.x_range,
                          y_range=self.y_range,
                          lod_threshold=None,
                          plot_width=model.plot_width,
                          plot_height=model.plot_height)
        fig = self.fig

        # no borders except 10px at the bottom; no axes, no grid lines
        fig.min_border_top = 0
        fig.min_border_bottom = 10
        fig.min_border_left = 0
        fig.min_border_right = 0
        fig.axis.visible = False
        fig.xgrid.grid_line_color = None
        fig.ygrid.grid_line_color = None

        # ---- renderers, appended in this order: basemap, datashader, labels
        self.tile_source = WMTSTileSource(url=model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        fig.renderers.append(self.tile_renderer)

        self.image_source = ImageSource(url=model.service_url,
                                        extra_url_vars=model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        fig.renderers.append(self.image_renderer)

        self.label_source = WMTSTileSource(url=model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        fig.renderers.append(self.label_renderer)

        # ---- hover layer (renderer + tool), shared with the model
        hover_layer = HoverLayer()
        hover_layer.field_name = model.field_title
        hover_layer.is_categorical = model.field in model.categorical_fields
        fig.renderers.append(hover_layer.renderer)
        fig.add_tools(hover_layer.tool)
        model.hover_layer = hover_layer

        # ---- legend containers filled elsewhere
        model.legend_side_vbox = VBox()
        model.legend_bottom_vbox = VBox()

        # ---- data / shading controls (left column, side legend last)
        axes_select = make_select('Axes', model.axes, self.on_axes_change)
        self.field_select = make_select('Field', model.fields,
                                        self.on_field_change)
        self.aggregate_select = make_select('Aggregate',
                                            model.aggregate_functions,
                                            self.on_aggregate_change)
        transfer_select = make_select('Transfer Function',
                                      model.transfer_functions,
                                      self.on_transfer_function_change)
        color_ramp_select = make_select('Color Ramp', model.color_ramps,
                                        self.on_color_ramp_change)
        spread_size_slider = make_slider("Spread Size (px)", 0, 0, 10,
                                         self.on_spread_size_change)
        hover_size_slider = make_slider("Hover Size (px)", 8, 4, 30,
                                        self.on_hover_size_change)
        controls = [
            axes_select,
            self.field_select,
            self.aggregate_select,
            transfer_select,
            color_ramp_select,
            spread_size_slider,
            hover_size_slider,
            model.legend_side_vbox,
        ]

        # ---- map controls (row above the figure)
        basemap_select = make_select('Basemap', model.basemaps,
                                     self.on_basemap_change, value='Imagery')
        image_opacity_slider = make_slider("Opacity", 100, 0, 100,
                                           self.on_image_opacity_slider_change)
        basemap_opacity_slider = make_slider(
            "Basemap Opacity", 100, 0, 100,
            self.on_basemap_opacity_slider_change)
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        # ---- overall layout
        self.controls = VBox(width=200, height=600, children=controls)
        self.map_controls = HBox(width=fig.plot_width,
                                 children=[basemap_select,
                                           basemap_opacity_slider,
                                           image_opacity_slider,
                                           show_labels_chk])
        self.map_area = VBox(width=fig.plot_width,
                             children=[self.map_controls, fig,
                                       model.legend_bottom_vbox])
        self.layout = HBox(width=1366, children=[self.controls, self.map_area])
Ejemplo n.º 11
0
    def create_layout(self):
        """Assemble an 800x560 map figure, its layers and the UI controls.

        Sets ``self.x_range``/``self.y_range`` from ``self.model.map_extent``,
        creates ``self.fig`` with basemap-tile, datashader-image and label-tile
        renderers, and builds ``self.controls``, ``self.map_controls``,
        ``self.map_area`` and ``self.layout``.
        """
        model = self.model
        extent = model.map_extent

        def make_select(name, options, handler, **extra):
            # Select whose 'value' changes are routed to ``handler``.
            widget = Select.create(name=name, options=options, **extra)
            widget.on_change('value', handler)
            return widget

        def make_slider(title, handler):
            # 0..100 percentage slider (step 1) starting at 100.
            widget = Slider(title=title, value=100, start=0, end=100, step=1)
            widget.on_change('value', handler)
            return widget

        # ---- figure spanning the model's map extent, axes hidden
        self.x_range = Range1d(start=extent[0], end=extent[2], bounds=None)
        self.y_range = Range1d(start=extent[1], end=extent[3], bounds=None)
        self.fig = Figure(tools='wheel_zoom,pan',
                          x_range=self.x_range,
                          y_range=self.y_range)
        fig = self.fig
        fig.plot_height = 560
        fig.plot_width = 800
        fig.axis.visible = False

        # ---- renderers, appended in this order: basemap, datashader, labels
        self.tile_source = WMTSTileSource(url=model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        fig.renderers.append(self.tile_renderer)

        self.image_source = ImageSource(url=model.service_url,
                                        extra_url_vars=model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        fig.renderers.append(self.image_renderer)

        self.label_source = WMTSTileSource(url=model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        fig.renderers.append(self.label_renderer)

        # ---- widgets
        controls = [
            make_select('Axes', model.axes, self.on_axes_change),
            make_select('Field', model.fields, self.on_field_change),
            make_select('Aggregate', model.aggregate_functions,
                        self.on_aggregate_change),
            make_select('Transfer Function', model.transfer_functions,
                        self.on_transfer_function_change),
        ]

        basemap_select = make_select('Basemap', model.basemaps,
                                     self.on_basemap_change, value='Toner')
        image_opacity_slider = make_slider("Opacity",
                                           self.on_image_opacity_slider_change)
        basemap_opacity_slider = make_slider(
            "Basemap Opacity", self.on_basemap_opacity_slider_change)
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        map_controls = [
            basemap_select, basemap_opacity_slider, image_opacity_slider,
            show_labels_chk
        ]

        # ---- overall layout
        self.controls = VBox(width=200, height=600, children=controls)
        self.map_controls = HBox(width=fig.plot_width, children=map_controls)
        self.map_area = VBox(width=fig.plot_width,
                             children=[self.map_controls, fig])
        self.layout = HBox(width=1024, children=[self.controls, self.map_area])
Ejemplo n.º 12
0
    def create_layout(self):
        """Build the 800x560 map view and the widgets that drive it.

        Sets ``self.x_range``/``self.y_range`` from ``self.model.map_extent``
        and creates ``self.fig``. The figure gets three renderers in order:
        basemap tiles, datashader image, label tiles. The method also
        assembles ``self.controls``, ``self.map_controls``, ``self.map_area``
        and ``self.layout``.
        """
        model = self.model
        x_min, y_min, x_max, y_max = (model.map_extent[0], model.map_extent[1],
                                      model.map_extent[2], model.map_extent[3])

        def wired_select(name, options, handler, **extra):
            # Select whose 'value' changes are routed to ``handler``.
            select = Select.create(name=name, options=options, **extra)
            select.on_change('value', handler)
            return select

        def opacity_slider(title, handler):
            # 0..100 percentage slider (step 1) starting at 100.
            slider = Slider(title=title, value=100, start=0, end=100, step=1)
            slider.on_change('value', handler)
            return slider

        # ---- figure
        self.x_range = Range1d(start=x_min, end=x_max, bounds=None)
        self.y_range = Range1d(start=y_min, end=y_max, bounds=None)
        fig = Figure(tools='wheel_zoom,pan', x_range=self.x_range,
                     y_range=self.y_range)
        self.fig = fig
        fig.plot_height = 560
        fig.plot_width = 800
        fig.axis.visible = False

        # ---- layers: basemap tiles, datashader image, label tiles
        self.tile_source = WMTSTileSource(url=model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        fig.renderers.append(self.tile_renderer)

        self.image_source = ImageSource(url=model.service_url,
                                        extra_url_vars=model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        fig.renderers.append(self.image_renderer)

        self.label_source = WMTSTileSource(url=model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        fig.renderers.append(self.label_renderer)

        # ---- data controls
        axes_select = wired_select('Axes', model.axes, self.on_axes_change)
        field_select = wired_select('Field', model.fields, self.on_field_change)
        aggregate_select = wired_select('Aggregate', model.aggregate_functions,
                                        self.on_aggregate_change)
        transfer_select = wired_select('Transfer Function',
                                       model.transfer_functions,
                                       self.on_transfer_function_change)

        # ---- map controls
        basemap_select = wired_select('Basemap', model.basemaps,
                                      self.on_basemap_change, value='Toner')
        image_opacity = opacity_slider("Opacity",
                                       self.on_image_opacity_slider_change)
        basemap_opacity = opacity_slider("Basemap Opacity",
                                         self.on_basemap_opacity_slider_change)
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        # ---- layout
        self.controls = VBox(width=200, height=600,
                             children=[axes_select, field_select,
                                       aggregate_select, transfer_select])
        self.map_controls = HBox(width=fig.plot_width,
                                 children=[basemap_select, basemap_opacity,
                                           image_opacity, show_labels_chk])
        self.map_area = VBox(width=fig.plot_width,
                             children=[self.map_controls, fig])
        self.layout = HBox(width=1024, children=[self.controls, self.map_area])
Ejemplo n.º 13
0
class EeghdfBrowser:
    """
    take an hdfeeg file and allow for browsing of the EEG signal

    just use the raw hdf file and conventions for now

    """
    def __init__(
        self,
        eeghdf_file,
        page_width_seconds=10.0,
        start_seconds=-1,
        montage="trace",
        montage_options=None,
        yscale=1.0,
        plot_width=950,
        plot_height=600,
    ):
        """
        Set up a browser over an eeghdf file and build the filter caches.

        @eeghdf_file - an eeghdf.Eeghdf instance to browse
        @page_width_seconds - how much signal to show per page, in seconds
        @start_seconds - center the view on this time (seconds); a negative
                         value centers on page_width_seconds / 2, i.e. the
                         page starts at the beginning of the file
        @montage - a string naming an entry in montage_options, or a montage
                   class/factory that is called with the reference labels
        @montage_options - mapping of montage name -> montage factory; when
                           None/empty the montageview builtins are used
        @yscale - gain multiplied into every displayed trace
        @plot_width, @plot_height - size of the Bokeh figure

        BTW 'trace' is what NK calls its 'as recorded' montage - might be
        better to call 'raw', 'default' or 'as recorded'
        """
        # montage_options defaults to None rather than a shared mutable {};
        # update_eeghdf_file treats any falsy value as "use the builtins".
        self.eeghdf_file = eeghdf_file
        self.update_eeghdf_file(eeghdf_file, montage, montage_options)

        # display related: page_width_secs is what the plotting code reads,
        # page_width_seconds is kept as a public alias
        self.page_width_seconds = page_width_seconds
        self.page_width_secs = page_width_seconds
        if start_seconds < 0:
            # default location: start of the file, if possible
            self.loc_sec = page_width_seconds / 2.0
        else:
            self.loc_sec = start_seconds

        self.yscale = yscale
        self.ui_plot_width = plot_width
        self.ui_plot_height = plot_height

        self.bk_handle = None  # notebook handle, set by bokeh_show()
        self.fig = None  # Bokeh figure, created lazily by the plot methods

        self.update_title()

        self.num_rows, self.num_samples = self.signals.shape
        self.line_glyphs = []  # not used?
        self.multi_line_glyph = None

        # range of montage channels that is displayed (slice semantics)
        self.ch_start = 0
        self.ch_stop = self.current_montage_instance.shape[0]

        ####### filter caches: UI label -> filter callable (None = no filter)
        self.current_hp_filter = None
        self.current_lp_filter = None

        self._highpass_cache = OrderedDict()
        self._highpass_cache["None"] = None
        self._highpass_cache["0.1 Hz"] = esfilters.fir_highpass_firwin_ff(
            self.fs, cutoff_freq=0.1, numtaps=int(self.fs))
        self._highpass_cache["0.3 Hz"] = esfilters.fir_highpass_firwin_ff(
            self.fs, cutoff_freq=0.3, numtaps=int(self.fs))
        self._highpass_cache["1 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs=self.fs, cutoff_freq=1.0, numtaps=int(2 * self.fs))
        self._highpass_cache["5 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs=self.fs, cutoff_freq=5.0, numtaps=int(0.2 * self.fs))

        # start with the same high-pass that the "LF" selector shows by default
        self.current_hp_filter = self._highpass_cache["0.3 Hz"]

        self._lowpass_cache = OrderedDict()
        self._lowpass_cache["None"] = None
        self._lowpass_cache["15 Hz"] = esfilters.fir_lowpass_firwin_ff(
            fs=self.fs, cutoff_freq=15.0, numtaps=int(self.fs / 2.0))
        self._lowpass_cache["30 Hz"] = esfilters.fir_lowpass_firwin_ff(
            fs=self.fs, cutoff_freq=30.0, numtaps=int(self.fs / 4.0))
        self._lowpass_cache["50 Hz"] = esfilters.fir_lowpass_firwin_ff(
            fs=self.fs, cutoff_freq=50.0, numtaps=int(self.fs / 4.0))
        self._lowpass_cache["70 Hz"] = esfilters.fir_lowpass_firwin_ff(
            fs=self.fs, cutoff_freq=70.0, numtaps=int(self.fs / 4.0))

        # 60 Hz notch is built once; it is enabled/disabled from the UI
        self._notch_filter = esfilters.notch_filter_iir_ff(notch_freq=60.0,
                                                           fs=self.fs,
                                                           Q=60)
        self.current_notch_filter = None

    @property
    def signals(self):
        return self.eeghdf_file.phys_signals

    def update_eeghdf_file(self,
                           eeghdf_file,
                           montage="trace",
                           montage_options={}):
        self.eeghdf_file = eeghdf_file
        hdf = eeghdf_file.hdf
        rec = hdf["record-0"]
        self.fs = rec.attrs["sample_frequency"]
        # self.signals = rec['signals']
        blabels = rec["signal_labels"]  # byte labels
        # self.electrode_labels = [str(ss,'ascii') for ss in blabels]
        self.electrode_labels = eeghdf_file.electrode_labels
        # fill in any missing ones
        if len(self.electrode_labels) < eeghdf_file.phys_signals.shape[0]:
            d = eeghdf_file.phys_signals.shape[0] - len(self.electrode_labels)
            ll = len(self.electrode_labels)
            suppl = [str(ii) for ii in range(ll, ll + d)]
            self.electrode_labels += suppl
            print("extending electrode lables:", suppl)

        # reference labels are used for montages, since this is an eeghdf file, it can provide these

        self.ref_labels = eeghdf_file.shortcut_elabels

        if not montage_options:
            # then use builtins and/or ones in the file
            montage_options = montageview.MONTAGE_BUILTINS.copy()
            # print('starting build of montage options', montage_options)

            # montage_options = eeghdf_file.get_montages()

        # defines self.current_montage_instance
        self.current_montage_instance = None
        if type(montage) == str:  # then we have some work to do
            if montage in montage_options:

                self.current_montage_instance = montage_options[montage](
                    self.ref_labels)
            else:
                raise Exception("unrecognized montage: %s" % montage)
        else:
            if montage:  # is a class
                self.current_montage_instance = montage(self.ref_labels)
                montage_options[self.current_montage_instance.name] = montage
            else:  # use default

                self.current_montage_instance = montage_options[0](
                    self.ref_labels)

        assert self.current_montage_instance
        try:  # to update ui display
            self.ui_montage_dropdown.value = self.current_montage_instance.name
        except AttributeError:
            # guess is not yet instantiated
            pass

        self.montage_options = montage_options  # save the montage_options for later
        self.update_title()
        # note this does not do any plotting or update the plot

    def update_title(self):
        self.title = "hdf %s - montage: %s" % (
            self.eeghdf_file.hdf.filename,
            self.current_montage_instance.full_name
            if self.current_montage_instance else "",
        )

    #         if showchannels=='all':
    #             self.ch_start = 0  # change this to a list of channels for fancy slicing
    #             if montage:
    #                 self.ch_stop = montage.shape[0] # all the channels in the montage
    #             self.ch_stop = signals.shape[0] # all the channels in the original signal
    #         else:
    #             self.ch_start, self.ch_stop = showchannels
    #         self.num_rows, self.num_samples = signals.shape
    #         self.line_glyphs = []
    #         self.multi_line_glyph = None

    def plot(self):
        """Create (or reuse) the Bokeh figure and draw the EEG page into it.

        Shows every channel of the current montage around ``self.loc_sec``,
        labels the x axis in seconds and puts an x grid line every second.
        The figure is stored in ``self.fig`` and also returned.
        """
        montage = self.current_montage_instance
        self.fig = self.show_montage_centered(
            self.signals,
            self.loc_sec,
            page_width_sec=self.page_width_secs,
            chstart=0,
            chstop=montage.shape[0],
            fs=self.fs,
            ylabels=montage.montage_labels,
            yscale=self.yscale,
            montage=montage,
        )
        figure = self.fig
        figure.xaxis.axis_label = "seconds"
        # one vertical grid line per second
        figure.xgrid.ticker = SingleIntervalTicker(interval=1.0)
        return figure

    def show_for_bokeh_app(self):
        """try running intside a bokeh app, so don't need notebook stuff"""
        self.plot()

    def bokeh_show(self):
        """
        Display the browser in a Jupyter notebook.

        Builds the figure, creates the top-bar widgets, shows the figure with a
        notebook handle (kept in ``self.bk_handle``) and finally registers the
        bottom-bar widgets.
        """
        self.plot()
        # NOTE(review): any layout returned by register_top_bar_ui is discarded
        # here - confirm the widgets are displayed somewhere else.
        self.register_top_bar_ui()
        handle = bokeh.plotting.show(self.fig, notebook_handle=True)
        self.bk_handle = handle
        self.register_bottom_bar_ui()

    def update(self):
        """
        Recompute the visible traces and push them into the plot's data source.

        Uses the current location (loc_sec), page width, montage, displayed
        channel range (ch_start:ch_stop), filters and gain (yscale). Requires
        the figure to exist already, since it writes into ``self.data_source``
        and uses ``self.ticklocs`` created by the plotting methods.

        Getting the change onto the screen is left to the event loop: e.g.
        bokeh.io.push_notebook() or panel.pane.Bokeh(...).
        """
        goto_sample = int(self.fs * self.loc_sec)
        page_width_samples = int(self.page_width_secs * self.fs)
        hw = int(page_width_samples / 2)
        s0 = limit_sample_check(goto_sample - hw, self.signals)
        s1 = limit_sample_check(goto_sample + hw, self.signals)
        window_samples = s1 - s0

        signal_view = self.signals[:, s0:s1]
        inmontage_view = np.dot(self.current_montage_instance.V.data,
                                signal_view)

        data = inmontage_view[self.ch_start:self.ch_stop, :]
        # Count the rows actually displayed (the ch_start:ch_stop slice), not
        # every montage row: otherwise `data` and self.ticklocs are overrun
        # whenever only a subset of channels is shown.
        num_rows = data.shape[0]

        # primitive per-channel filtering: notch, then high-pass, then low-pass
        for filt in (self.current_notch_filter, self.current_hp_filter,
                     self.current_lp_filter):
            if filt:
                for ii in range(num_rows):
                    data[ii, :] = filt(data[ii, :])

        # Sample times in seconds, 1/fs apart from the window start. Deriving
        # the spacing from page_width_secs instead would stretch the axis when
        # the window is clipped at either end of the recording.
        t = (s0 + np.arange(window_samples, dtype=float)) / self.fs

        xs = [t for _ in range(num_rows)]
        ys = [
            self.yscale * data[ii, :] + self.ticklocs[ii]
            for ii in range(num_rows)
        ]
        self.data_source.data.update(dict(xs=xs, ys=ys))

    def stackplot_t(
        self,
        tarray,
        seconds=None,
        start_time=None,
        ylabels=None,
        yscale=1.0,
        topdown=True,
        **kwargs,
    ):
        """
        will plot a stack of traces one above the other assuming
        @tarray is an nd-array like object with format
        tarray.shape = numSamples, numRows   (channels along axis 1)

        @seconds = width of plot in seconds for labeling purposes (optional)
        @start_time is start time in seconds for the plot (optional)
        @ylabels a list of labels for each row ("channel") in tarray
        @yscale will increase (multiply) the signals in each row by this amount
        @topdown if true the first channel is drawn at the top
        @kwargs passed to bokeh.plotting.figure when a new figure is created
                (plot_width / plot_height default to the browser's ui sizes)

        Creates self.fig on first use, sets self.y_range, self.ticklocs,
        self.time, self.multi_line_glyph and self.data_source, and returns
        self.fig. tarray itself is not modified; filtering works on a copy.
        """
        # float copy: the filters below write into `data`, and the caller's
        # array (e.g. the transposed view built by stackplot) must not change
        data = np.array(tarray, dtype=float)
        numSamples, numRows = data.shape

        if seconds:
            t = seconds * np.arange(numSamples, dtype=float) / numSamples
            if start_time:
                t = t + start_time
        else:
            t = np.arange(numSamples, dtype=float)

        # a default width that is wider but can just fit in jupyter
        kwargs.setdefault("plot_width", self.ui_plot_width)
        kwargs.setdefault("plot_height", self.ui_plot_height)

        if not self.fig:
            # bokeh.plotting.figure creates a subclass of Plot
            self.fig = bokeh.plotting.figure(
                title=self.title,
                tools="pan,box_zoom,reset,lasso_select,ywheel_zoom",
                **kwargs,
            )

        # space the traces by 70% of the overall data range (crowd them a bit)
        dr = (data.max() - data.min()) * 0.7
        # plain Python floats so the label dict below reprs as valid JavaScript
        ticklocs = [float(ii * dr) for ii in range(numRows)]
        bottom = -dr / 0.7
        top = (numRows - 1) * dr + dr / 0.7
        self.y_range = Range1d(bottom, top)
        self.fig.y_range = self.y_range

        if topdown:
            ticklocs.reverse()  # inplace: first channel at the top

        self.ticklocs = ticklocs
        self.time = t

        # primitive per-channel filtering: notch, then high-pass, then
        # low-pass. Channels are along axis 1 here, so every filter must
        # index columns (data[:, ii]); indexing rows would filter time
        # samples instead of channels.
        for filt in (self.current_notch_filter, self.current_hp_filter,
                     self.current_lp_filter):
            if filt:
                for ii in range(numRows):
                    data[:, ii] = filt(data[:, ii])

        # one multi_line glyph backed by a single data source
        xs = [t for _ in range(numRows)]
        ys = [yscale * data[:, ii] + ticklocs[ii] for ii in range(numRows)]
        self.multi_line_glyph = self.fig.multi_line(xs=xs, ys=ys)
        self.data_source = self.multi_line_glyph.data_source

        # label each trace at its offset on the y axis
        if not ylabels:
            ylabels = ["%d" % ii for ii in range(numRows)]
        self.fig.yaxis.ticker = FixedTicker(ticks=ticklocs)
        # The {tickloc: label} dict is embedded as a JS object literal via its
        # Python repr, so keys/labels must be plain float/str (NumPy >= 2 reprs
        # scalars as e.g. "np.float64(1.0)", which is not valid JavaScript).
        ylabel_dict = {
            float(loc): str(label) for loc, label in zip(ticklocs, ylabels)
        }
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % ylabel_dict)

        return self.fig

    def update_plot_after_montage_change(self):
        """
        Re-layout the existing figure after the montage has changed.

        Updates the figure title, then recomputes the channel count and trace
        spacing for the current montage over the current window. It updates
        the y range, tick locations and tick labels, and resets
        ch_start/ch_stop to show every montage channel. The trace data itself
        is not replaced here; update() redraws the signals.
        """
        self.fig.title.text = self.title
        goto_sample = int(self.fs * self.loc_sec)
        page_width_samples = int(self.page_width_secs * self.fs)
        hw = int(page_width_samples / 2)
        s0 = limit_sample_check(goto_sample - hw, self.signals)
        s1 = limit_sample_check(goto_sample + hw, self.signals)

        signal_view = self.signals[:, s0:s1]
        inmontage_view = np.dot(self.current_montage_instance.V.data,
                                signal_view)
        numRows = inmontage_view.shape[0]

        # after a montage change every montage channel is displayed
        self.ch_start = 0
        self.ch_stop = numRows
        data = inmontage_view[self.ch_start:self.ch_stop, :]

        # space the traces by 70% of the overall data range (crowd them a bit)
        dr = (data.max() - data.min()) * 0.7
        # plain Python floats so the label dict below reprs as valid JavaScript
        ticklocs = [float(ii * dr) for ii in range(numRows)]
        ticklocs.reverse()  # inplace: first channel at the top
        self.y_range.start = -dr / 0.7
        self.y_range.end = (numRows - 1) * dr + dr / 0.7
        self.ticklocs = ticklocs

        ylabels = self.current_montage_instance.montage_labels
        self.fig.yaxis.ticker = FixedTicker(ticks=ticklocs)
        # The {tickloc: label} dict is embedded as a JS object literal via its
        # Python repr, so keys/labels must be plain float/str (NumPy >= 2 reprs
        # scalars as e.g. "np.float64(1.0)", which is not valid JavaScript).
        ylabel_dict = {
            float(loc): str(label) for loc, label in zip(ticklocs, ylabels)
        }
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % ylabel_dict)

    def stackplot(
        self,
        marray,
        seconds=None,
        start_time=None,
        ylabels=None,
        yscale=1.0,
        topdown=True,
        **kwargs,
    ):
        """
        will plot a stack of traces one above the other assuming
        @marray contains the data you want to plot
        marray.shape = numRows, numSamples

        @seconds = with of plot in seconds for labeling purposes (optional)
        @start_time is start time in seconds for the plot (optional)

        @ylabels a list of labels for each row ("channel") in marray
        @yscale with increase (mutiply) the signals in each row by this amount
        """
        tarray = np.transpose(marray)
        return self.stackplot_t(
            tarray,
            seconds=seconds,
            start_time=start_time,
            ylabels=ylabels,
            yscale=yscale,
            topdown=True,
            **kwargs,
        )

    def show_epoch_centered(
        self,
        signals,
        goto_sec,
        page_width_sec,
        chstart,
        chstop,
        fs,
        ylabels=None,
        yscale=1.0,
    ):
        """
        Plot raw (un-montaged) channels chstart:chstop in a window of
        page_width_sec seconds centered on goto_sec, via stackplot.

        @signals array-like object with signals[ch_num, sample_num]
        @goto_sec where to go in the signal to show the feature
        @page_width_sec length of the window to show in secs
        @chstart   which channel to start
        @chstop    which channel to end (slice semantics, exclusive)
        @fs sample frequency (num samples per second)
        @ylabels optional list of labels indexed by channel number; it is
                 sliced to chstart:chstop before plotting
        @yscale gain multiplied into every trace

        The window bounds go through limit_sample_check (presumably clamping
        them to the recording). Returns the figure produced by stackplot.
        """
        goto_sample = int(fs * goto_sec)
        hw = int(page_width_sec * fs / 2)  # half the window, in samples

        s0 = limit_sample_check(goto_sample - hw, signals)
        s1 = limit_sample_check(goto_sample + hw, signals)
        duration = (s1 - s0) / fs
        start_time_sec = s0 / fs

        # ylabels is optional (default None); only slice it when given
        labels = ylabels[chstart:chstop] if ylabels is not None else None

        return self.stackplot(
            signals[chstart:chstop, s0:s1],
            start_time=start_time_sec,
            seconds=duration,
            ylabels=labels,
            yscale=yscale,
        )

    def show_montage_centered(
        self,
        signals,
        goto_sec,
        page_width_sec,
        chstart,
        chstop,
        fs,
        ylabels=None,
        yscale=1.0,
        montage=None,
        topdown=True,
        **kwargs,
    ):
        """
        plot an eeg segment using a montage, center the plot at @goto_sec
        with @page_width_sec shown

        @signals array-like object with signals[ch_num, sample_num]
        @goto_sec where to go in the signal to show the feature
        @page_width_sec length of the window to show in secs
        @chstart, @chstop  range of montage channels to show (slice semantics)
        @fs sample frequency (num samples per second)
        @ylabels a list of labels, one per displayed channel (defaults to
                 the channel numbers)
        @yscale will increase (multiply) the signals in each row by this amount
        @montage montage instance; its V.data matrix is multiplied with the
                 raw channels. Defaults to self.current_montage_instance.
        @topdown if true the first channel is drawn at the top
        @kwargs passed to bokeh.plotting.figure when the figure is created
                (plot_width / plot_height default to the browser's ui sizes)

        Creates self.fig on first use, sets self.y_range, self.ticklocs,
        self.time, self.multi_line_glyph and self.data_source, and returns
        self.fig.
        """
        if montage is None:
            montage = self.current_montage_instance

        goto_sample = int(fs * goto_sec)
        hw = int(page_width_sec * fs / 2)  # half the window, in samples

        s0 = limit_sample_check(goto_sample - hw, signals)
        s1 = limit_sample_check(goto_sample + hw, signals)
        duration_sec = (s1 - s0) / fs
        start_time_sec = s0 / fs

        # re-reference the whole window into the montage, then keep the
        # requested channel range; data is (numRows, numSamples)
        signal_view = signals[:, s0:s1]
        inmontage_view = np.dot(montage.V.data, signal_view)
        data = inmontage_view[chstart:chstop, :]
        numRows, numSamples = data.shape

        t = duration_sec * np.arange(numSamples, dtype=float) / numSamples
        t = t + start_time_sec  # shift over

        # a default width that is wider but can just fit in jupyter
        kwargs.setdefault("plot_width", self.ui_plot_width)
        kwargs.setdefault("plot_height", self.ui_plot_height)

        if not self.fig:
            self.fig = bokeh.plotting.figure(
                title=self.title,
                tools="crosshair",
                **kwargs,
            )

        # space the traces by 70% of the overall data range (crowd them a bit)
        dr = (data.max() - data.min()) * 0.7
        # plain Python floats so the label dict below reprs as valid JavaScript
        ticklocs = [float(ii * dr) for ii in range(numRows)]
        bottom = -dr / 0.7
        top = (numRows - 1) * dr + dr / 0.7
        self.y_range = Range1d(bottom, top)
        self.fig.y_range = self.y_range

        if topdown:
            ticklocs.reverse()  # inplace: first channel at the top

        self.ticklocs = ticklocs
        self.time = t

        # primitive per-channel filtering (channels along axis 0 here):
        # notch, then high-pass, then low-pass
        for filt in (self.current_notch_filter, self.current_hp_filter,
                     self.current_lp_filter):
            if filt:
                for ii in range(numRows):
                    data[ii, :] = filt(data[ii, :])

        # one multi_line glyph backed by a single data source
        xs = [t for _ in range(numRows)]
        ys = [yscale * data[ii, :] + ticklocs[ii] for ii in range(numRows)]
        self.multi_line_glyph = self.fig.multi_line(xs=xs, ys=ys)
        self.data_source = self.multi_line_glyph.data_source

        # label each trace at its offset on the y axis
        if not ylabels:
            ylabels = ["%d" % ii for ii in range(numRows)]
        self.fig.yaxis.ticker = FixedTicker(ticks=ticklocs)
        # The {tickloc: label} dict is embedded as a JS object literal via its
        # Python repr, so keys/labels must be plain float/str (NumPy >= 2 reprs
        # scalars as e.g. "np.float64(1.0)", which is not valid JavaScript).
        ylabel_dict = {
            float(loc): str(label) for loc, label in zip(ticklocs, ylabels)
        }
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % ylabel_dict)
        return self.fig

    # NOTE: an earlier, ipywidgets-based register_top_bar_ui used to be defined
    # here. It was dead code: the Bokeh-widget register_top_bar_ui defined
    # later in this class body rebinds the same name, so this version could
    # never be called. It also referenced undefined names (`change`,
    # `flayout`) and would have raised NameError if it were ever reached.

    def register_top_bar_ui(self):
        """Create the top toolbar widgets and wire up their callbacks.

        Widgets created and stored on ``self``:

        * ``ui_montage_dropdown`` -- ``Select`` over ``self.montage_options``;
          a new choice calls ``update_montage``,
          ``update_plot_after_montage_change`` and ``update``.
        * ``ui_low_freq_filter_dropdown`` -- ``Select`` over the keys of
          ``self._highpass_cache``; sets ``current_hp_filter``.
        * ``ui_high_freq_filter_dropdown`` -- ``Select`` over the keys of
          ``self._lowpass_cache``; sets ``current_lp_filter``.
        * ``ui_notch_option`` -- single-box ``CheckboxGroup`` switching
          ``current_notch_filter`` between ``self._notch_filter`` and ``None``.
        * ``ui_gain_bounded_float`` -- ``Spinner`` whose value (as float)
          becomes ``yscale``.

        Every callback ends with ``update()`` so the display is redrawn.

        Returns:
            The ``bokeh.layouts.row`` holding the widgets (also kept in
            ``self.top_bar_layout``).
        """
        # Montage selector.
        self.ui_montage_dropdown = Select(
            options=list(self.montage_options.keys()),
            value=str(self.current_montage_instance.name),
            title="Montage:",
        )

        def on_montage_selected(attr, oldvalue, newvalue, parent=self):
            parent.update_montage(newvalue)
            parent.update_plot_after_montage_change()
            parent.update()

        self.ui_montage_dropdown.on_change("value", on_montage_selected)

        # Low-frequency cut-off, backed by the high-pass filter cache.
        self.ui_low_freq_filter_dropdown = Select(
            options=list(self._highpass_cache.keys()),
            value="0.3 Hz",
            title="LF",
            max_width=150,  # presumably pixels -- original author's note
        )

        def on_low_freq_selected(attr, oldvalue, newvalue, parent=self):
            parent.current_hp_filter = parent._highpass_cache[newvalue]
            parent.update()

        self.ui_low_freq_filter_dropdown.on_change("value",
                                                   on_low_freq_selected)

        # High-frequency cut-off, backed by the low-pass filter cache.
        self.ui_high_freq_filter_dropdown = Select(
            options=list(self._lowpass_cache.keys()),
            value="None",
            title="HF",
            max_width=150,
        )

        def on_high_freq_selected(attr, oldvalue, newvalue, parent=self):
            parent.current_lp_filter = parent._lowpass_cache[newvalue]
            parent.update()

        self.ui_high_freq_filter_dropdown.on_change("value",
                                                    on_high_freq_selected)

        # One checkbox toggling the notch filter.
        self.ui_notch_option = CheckboxGroup(labels=["60Hz Notch"])

        def on_notch_toggled(newvalue, parent=self):
            # newvalue is the list of ticked indices: [0] or [].
            if newvalue == [0]:
                parent.current_notch_filter = parent._notch_filter
            elif newvalue == []:
                parent.current_notch_filter = None
            parent.update()

        self.ui_notch_option.on_click(on_notch_toggled)

        # Vertical gain control.
        self.ui_gain_bounded_float = Spinner(
            value=1.0,
            step=0.1,
            title="gain",
            width=100,
        )

        def on_gain_changed(attr, oldvalue, newvalue, parent=self):
            parent.yscale = float(newvalue)
            parent.update()

        self.ui_gain_bounded_float.on_change("value", on_gain_changed)

        self.top_bar_layout = bokeh.layouts.row(
            self.ui_montage_dropdown,
            self.ui_low_freq_filter_dropdown,
            self.ui_high_freq_filter_dropdown,
            self.ui_gain_bounded_float,
            self.ui_notch_option,
        )
        return self.top_bar_layout

    def _limit_time_check(self, candidate):
        """Clamp a requested view position to the recording's time span.

        Returns ``float(self.eeghdf_file.duration_seconds)`` when *candidate*
        lies past the end of the recording, ``0.0`` when it is negative, and
        *candidate* itself (unconverted) otherwise.
        """
        end_of_record = self.eeghdf_file.duration_seconds
        if candidate > end_of_record:
            return float(end_of_record)
        return 0.0 if candidate < 0 else candidate

    def register_bottom_bar_ui(self):
        """Create the bottom navigation bar of time-step buttons.

        Four ``Button`` widgets move the current view position
        (``self.loc_sec``) backward/forward by 10 s and by 1 s. Each new
        position is clamped with ``self._limit_time_check`` and the display
        is refreshed with ``self.update()``.

        Returns:
            The ``bokeh.layouts.row`` holding the buttons (also kept in
            ``self.ui_bottom_bar_layout``).
        """
        self.ui_buttonf = Button(label="go forward 10s")
        self.ui_buttonback = Button(label="go backward 10s")
        self.ui_buttonf1 = Button(label="forward 1 s")
        self.ui_buttonback1 = Button(label="back 1 s")

        def make_step_handler(delta_sec):
            """Return a click handler that shifts ``self.loc_sec`` by *delta_sec*."""
            # A factory binds delta_sec per handler (no late-binding closure
            # pitfalls) and replaces four copy-pasted callbacks.
            def on_step(event):
                self.loc_sec = self._limit_time_check(self.loc_sec + delta_sec)
                self.update()

            return on_step

        self.ui_buttonf.on_click(make_step_handler(10))
        self.ui_buttonback.on_click(make_step_handler(-10))
        self.ui_buttonf1.on_click(make_step_handler(1))
        self.ui_buttonback1.on_click(make_step_handler(-1))

        # TODO: a "go to time" numeric input could be added here; its
        # on_change handler would set loc_sec through _limit_time_check.

        self.ui_bottom_bar_layout = bokeh.layouts.row(
            self.ui_buttonback,
            self.ui_buttonf,
            self.ui_buttonback1,
            self.ui_buttonf1,
        )
        return self.ui_bottom_bar_layout

    def update_montage(self, montage_name):
        """Make the montage registered under *montage_name* the current one.

        Instantiates ``self.montage_options[montage_name]`` with
        ``self.ref_labels``, stores it as ``current_montage_instance``, resets
        the channel window to start at 0 and end at the first dimension of
        the montage's ``shape``, then calls ``update_title()``.
        """
        montage_cls = self.montage_options[montage_name]
        montage = montage_cls(self.ref_labels)
        self.current_montage_instance = montage
        self.ch_start = 0
        self.ch_stop = montage.shape[0]
        self.update_title()
    def create_layout(self):
        """Build the map figure, its layers, the hover tool and the control panels.

        Sets ``self.fig`` (x range from ``map_extent[0]`` to ``map_extent[2]``,
        y range from ``map_extent[1]`` to ``map_extent[3]``), appends a
        basemap tile layer, the datashader image layer and a label tile layer,
        adds the hover layer, and assembles ``self.controls``,
        ``self.map_controls``, ``self.map_area`` and ``self.layout``.

        NOTE(review): ``Select.create``, ``VBox`` and ``HBox`` come from older
        Bokeh releases; confirm against the Bokeh version this project pins.
        """

        # create figure
        self.x_range = Range1d(start=self.model.map_extent[0],
                               end=self.model.map_extent[2], bounds=None)
        self.y_range = Range1d(start=self.model.map_extent[1],
                               end=self.model.map_extent[3], bounds=None)

        self.fig = Figure(tools='wheel_zoom,pan',
                          x_range=self.x_range,
                          lod_threshold=None,
                          plot_width=self.model.plot_width,
                          plot_height=self.model.plot_height,
                          y_range=self.y_range)

        # Map-only view: no side borders, hidden axes and grid lines.
        self.fig.min_border_top = 0
        self.fig.min_border_bottom = 10
        self.fig.min_border_left = 0
        self.fig.min_border_right = 0
        self.fig.axis.visible = False

        self.fig.xgrid.grid_line_color = None
        self.fig.ygrid.grid_line_color = None

        # Layers are appended basemap -> data image -> labels; keep this
        # order (presumably it controls draw stacking -- confirm).

        # add tiled basemap
        self.tile_source = WMTSTileSource(url=self.model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        self.fig.renderers.append(self.tile_renderer)

        # add datashader layer
        self.image_source = ImageSource(url=self.model.service_url,
                                        extra_url_vars=self.model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        self.fig.renderers.append(self.image_renderer)

        # add label layer
        self.label_source = WMTSTileSource(url=self.model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        self.fig.renderers.append(self.label_renderer)

        # Add a hover tool
        hover_layer = HoverLayer()
        hover_layer.field_name = self.model.field_title
        hover_layer.is_categorical = self.model.field in self.model.categorical_fields
        self.fig.renderers.append(hover_layer.renderer)
        self.fig.add_tools(hover_layer.tool)
        self.model.hover_layer = hover_layer

        # Empty containers for legends, stored on the model (filled elsewhere).
        self.model.legend_side_vbox = VBox()
        self.model.legend_bottom_vbox = VBox()

        # add ui components
        # Side panel widgets, in display order; each forwards value changes
        # to the matching self.on_*_change handler.
        controls = []
        axes_select = Select.create(name='Axes',
                                    options=self.model.axes)
        axes_select.on_change('value', self.on_axes_change)
        controls.append(axes_select)

        self.field_select = Select.create(name='Field', options=self.model.fields)
        self.field_select.on_change('value', self.on_field_change)
        controls.append(self.field_select)

        self.aggregate_select = Select.create(name='Aggregate',
                                              options=self.model.aggregate_functions)
        self.aggregate_select.on_change('value', self.on_aggregate_change)
        controls.append(self.aggregate_select)

        transfer_select = Select.create(name='Transfer Function',
                                        options=self.model.transfer_functions)
        transfer_select.on_change('value', self.on_transfer_function_change)
        controls.append(transfer_select)

        color_ramp_select = Select.create(name='Color Ramp', options=self.model.color_ramps)
        color_ramp_select.on_change('value', self.on_color_ramp_change)
        controls.append(color_ramp_select)

        spread_size_slider = Slider(title="Spread Size (px)", value=0, start=0,
                                    end=10, step=1)
        spread_size_slider.on_change('value', self.on_spread_size_change)
        controls.append(spread_size_slider)

        hover_size_slider = Slider(title="Hover Size (px)", value=8, start=4,
                                   end=30, step=1)
        hover_size_slider.on_change('value', self.on_hover_size_change)
        controls.append(hover_size_slider)

        controls.append(self.model.legend_side_vbox)

        # add map components
        # Row of widgets shown above the map.
        basemap_select = Select.create(name='Basemap', value='Imagery',
                                       options=self.model.basemaps)
        basemap_select.on_change('value', self.on_basemap_change)

        image_opacity_slider = Slider(title="Opacity", value=100, start=0,
                                      end=100, step=1)
        image_opacity_slider.on_change('value', self.on_image_opacity_slider_change)

        basemap_opacity_slider = Slider(title="Basemap Opacity", value=100, start=0,
                                        end=100, step=1)
        basemap_opacity_slider.on_change('value', self.on_basemap_opacity_slider_change)

        # Starts checked (active=[0]), i.e. labels shown by default.
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        map_controls = [basemap_select, basemap_opacity_slider,
                        image_opacity_slider, show_labels_chk]

        # Final layout: side controls | (map controls / figure / bottom legend).
        self.controls = VBox(width=200, height=600, children=controls)
        self.map_controls = HBox(width=self.fig.plot_width, children=map_controls)
        self.map_area = VBox(width=self.fig.plot_width, children=[self.map_controls,
                                                                  self.fig,
                                                                  self.model.legend_bottom_vbox])
        self.layout = HBox(width=1366, children=[self.controls, self.map_area])
Ejemplo n.º 15
0
    def create_layout(self):
        """Build the map figure, its layers, the hover tool and the control panels.

        Sets ``self.fig`` (x range from ``map_extent[0]`` to ``map_extent[2]``,
        y range from ``map_extent[1]`` to ``map_extent[3]``), appends a
        basemap tile layer, the datashader image layer and a label tile layer,
        adds a hover tool over ``self.model.hover_source``, and assembles
        ``self.controls``, ``self.map_controls``, ``self.map_area`` and
        ``self.layout``.

        NOTE(review): ``Select.create``, ``VBox``/``HBox`` and the JS
        ``source.set(...)`` call come from older Bokeh releases; confirm
        against the Bokeh version this project pins.
        """

        # create figure
        self.x_range = Range1d(start=self.model.map_extent[0],
                               end=self.model.map_extent[2], bounds=None)
        self.y_range = Range1d(start=self.model.map_extent[1],
                               end=self.model.map_extent[3], bounds=None)

        self.fig = Figure(tools='wheel_zoom,pan',
                          x_range=self.x_range,
                          lod_threshold=None,
                          plot_width=self.model.plot_width,
                          plot_height=self.model.plot_height,
                          y_range=self.y_range)

        # Map-only view: no side borders, hidden axes and grid lines.
        self.fig.min_border_top = 0
        self.fig.min_border_bottom = 10
        self.fig.min_border_left = 0
        self.fig.min_border_right = 0
        self.fig.axis.visible = False

        self.fig.xgrid.grid_line_color = None
        self.fig.ygrid.grid_line_color = None

        # Layers are appended basemap -> data image -> labels; keep this
        # order (presumably it controls draw stacking -- confirm).

        # add tiled basemap
        self.tile_source = WMTSTileSource(url=self.model.basemap)
        self.tile_renderer = TileRenderer(tile_source=self.tile_source)
        self.fig.renderers.append(self.tile_renderer)

        # add datashader layer
        self.image_source = ImageSource(url=self.model.service_url,
                                        extra_url_vars=self.model.shader_url_vars)
        self.image_renderer = DynamicImageRenderer(image_source=self.image_source)
        self.fig.renderers.append(self.image_renderer)

        # add label layer
        self.label_source = WMTSTileSource(url=self.model.labels_url)
        self.label_renderer = TileRenderer(tile_source=self.label_source)
        self.fig.renderers.append(self.label_renderer)

        # Add a hover tool: squares are invisible by default; the hovered
        # (selected) one is drawn with the visible style.
        self.invisible_square = Square(x='x',
                                       y='y',
                                       fill_color=None,
                                       line_color=None,
                                       size=self.model.hover_size)

        self.visible_square = Square(x='x',
                                     y='y',
                                     fill_color='#79DCDE',
                                     fill_alpha=.5,
                                     line_color='#79DCDE',
                                     line_alpha=1,
                                     size=self.model.hover_size)

        cr = self.fig.add_glyph(self.model.hover_source,
                                self.invisible_square,
                                selection_glyph=self.visible_square,
                                nonselection_glyph=self.invisible_square)

        # JS side: mark the hovered indices as the source's selection.
        code = "source.set('selected', cb_data['index']);"
        callback = CustomJS(args={'source': self.model.hover_source}, code=code)
        # dict.keys() returns a non-indexable view on Python 3, so the
        # original ``fields.keys()[0]`` raised TypeError there; materialize
        # the keys first (also valid on Python 2).
        first_field = list(self.model.fields.keys())[0]
        self.model.hover_tool = HoverTool(tooltips=[(first_field, "@value")],
                                          callback=callback,
                                          renderers=[cr],
                                          mode='mouse')
        self.fig.add_tools(self.model.hover_tool)

        # Empty containers for legends, stored on the model (filled elsewhere).
        self.model.legend_side_vbox = VBox()
        self.model.legend_bottom_vbox = VBox()

        # add ui components
        # Side panel widgets, in display order; each forwards value changes
        # to the matching self.on_*_change handler.
        controls = []
        axes_select = Select.create(name='Axes',
                                    options=self.model.axes)
        axes_select.on_change('value', self.on_axes_change)
        controls.append(axes_select)

        self.field_select = Select.create(name='Field', options=self.model.fields)
        self.field_select.on_change('value', self.on_field_change)
        controls.append(self.field_select)

        self.aggregate_select = Select.create(name='Aggregate',
                                              options=self.model.aggregate_functions)
        self.aggregate_select.on_change('value', self.on_aggregate_change)
        controls.append(self.aggregate_select)

        transfer_select = Select.create(name='Transfer Function',
                                        options=self.model.transfer_functions)
        transfer_select.on_change('value', self.on_transfer_function_change)
        controls.append(transfer_select)

        color_ramp_select = Select.create(name='Color Ramp', options=self.model.color_ramps)
        color_ramp_select.on_change('value', self.on_color_ramp_change)
        controls.append(color_ramp_select)

        spread_size_slider = Slider(title="Spread Size (px)", value=0, start=0,
                                    end=10, step=1)
        spread_size_slider.on_change('value', self.on_spread_size_change)
        controls.append(spread_size_slider)

        hover_size_slider = Slider(title="Hover Size (px)", value=8, start=4,
                                   end=30, step=1)
        hover_size_slider.on_change('value', self.on_hover_size_change)
        controls.append(hover_size_slider)

        controls.append(self.model.legend_side_vbox)

        # add map components
        # Row of widgets shown above the map.
        basemap_select = Select.create(name='Basemap', value='Imagery',
                                       options=self.model.basemaps)
        basemap_select.on_change('value', self.on_basemap_change)

        image_opacity_slider = Slider(title="Opacity", value=100, start=0,
                                      end=100, step=1)
        image_opacity_slider.on_change('value', self.on_image_opacity_slider_change)

        basemap_opacity_slider = Slider(title="Basemap Opacity", value=100, start=0,
                                        end=100, step=1)
        basemap_opacity_slider.on_change('value', self.on_basemap_opacity_slider_change)

        # Starts checked (active=[0]), i.e. labels shown by default.
        show_labels_chk = CheckboxGroup(labels=["Show Labels"], active=[0])
        show_labels_chk.on_click(self.on_labels_change)

        map_controls = [basemap_select, basemap_opacity_slider,
                        image_opacity_slider, show_labels_chk]

        # Final layout: side controls | (map controls / figure / bottom legend).
        self.controls = VBox(width=200, height=600, children=controls)
        self.map_controls = HBox(width=self.fig.plot_width, children=map_controls)
        self.map_area = VBox(width=self.fig.plot_width, children=[self.map_controls,
                                                                  self.fig,
                                                                  self.model.legend_bottom_vbox])
        self.layout = HBox(width=1366, children=[self.controls, self.map_area])
        'y6': daily_infected,
        'y7': daily_dead
    }


def update_plots(new):
    """Show only the line pairs whose checkbox is ticked.

    ``new`` (the callback argument) is ignored; the active indices are
    re-read from ``checkbox_group.active``. For every entry ``i`` of
    ``lines``, both renderers ``lines[i][0]`` and ``lines[i][1]`` become
    visible iff ``i`` is among the active indices.
    """
    active_indices = checkbox_group.active
    for idx, line_pair in enumerate(lines):
        is_visible = idx in active_indices
        line_pair[0].visible = is_visible
        line_pair[1].visible = is_visible

# Re-apply line visibility whenever the series checkboxes change.
checkbox_group.on_click(update_plots)

## Model 1 selections
# Each selector's "value" change goes to its dedicated handler; the button
# click runs update_model_1.
select_policy_model_1.on_change('value', callback_select_policy_model_1)
select_object_model_1.on_change('value', callback_select_object_model_1)
select_object_type_model_1.on_change('value', callback_select_object_type_model_1)
select_percentile_model_1.on_change('value', callback_select_percentile_model_1)
select_start_model_1.on_change('value', callback_select_start_model_1)
select_period_model_1.on_change('value', callback_select_period_model_1)
button_model_1.on_click(update_model_1)

## Model 2 selections
# NOTE(review): unlike model 1, no start/period selector or button
# registration is visible for model 2 here -- this snippet may be truncated;
# confirm the remaining registrations exist.
select_policy_model_2.on_change('value', callback_select_policy_model_2)
select_object_model_2.on_change('value', callback_select_object_model_2)
select_object_type_model_2.on_change('value', callback_select_object_type_model_2)
select_percentile_model_2.on_change('value', callback_select_percentile_model_2)
Ejemplo n.º 17
0
class Visual:

    def __init__(self, callbackFunc, running):
        """Create the data sources, plot layout and document for the simulation.

        Parameters
        ----------
        callbackFunc
            Stored as ``self.callbackFunc`` (not called here).
        running
            Stored as ``self.running`` (not interpreted here).
        """
        self.running = running
        self.callbackFunc = callbackFunc

        # Chart source seeded with a single day-0 row for the selected region.
        region = config.region
        seed = dict(
            x=[0],
            sus=[config.param_init_susceptible[region]],
            exp=[config.param_init_exposed[region]],
            inf=[0],
            sin=[0],
            qua=[0],
            imm=[0],
            dea=[0],
            text=[""],
            mdates=[""],
        )
        self.source = ColumnDataSource(seed)

        self.sourceJS = ColumnDataSource(dict(text=[]))

        # Browser-side hook: publishes the chart data as ``window.data``
        # whenever the source emits a 'change' event.
        mcallback = CustomJS(args=dict(source=self.source), code="""
            window.data  = source.data

            console.log(source)
        """)
        self.source.js_on_change('change', mcallback)

        self.tools = 'pan, box_zoom, wheel_zoom, reset'
        self.plot_options = dict(plot_width=800, plot_height=600, tools=[self.tools])
        self.updateValue = True
        self.pAll = self.definePlot(self.source)
        self.doc = curdoc()
        self.layout()
        self.prev_y1 = 0

        self.region_names = config.region_names

        # Seed each widget (presumably created by self.layout()) from the
        # per-region defaults in config, in this fixed order.
        widget_defaults = (
            ("init_exposed", "param_init_exposed"),
            ("sus_to_exp_slider", "param_beta_exp"),
            ("param_qr_slider", "param_qr"),
            ("param_sir", "param_sir"),
            ("param_hosp_capacity", "param_hosp_capacity"),
            ("param_gamma_mor1", "param_gamma_mor1"),
            ("param_gamma_mor2", "param_gamma_mor2"),
            ("param_gamma_im", "param_gamma_im"),
            ("param_eps_exp", "param_eps_exp"),
            ("param_eps_qua", "param_eps_qua"),
            ("param_eps_sev", "param_eps_sev"),
        )
        for widget_attr, config_attr in widget_defaults:
            getattr(self, widget_attr).value = getattr(config, config_attr)[config.region]

        self.start_date = date.today()

        # Transition-matrix checkbox selections, each initially indices 0..16
        # (three independent lists).
        self.box1 = list(range(17))
        self.box2 = list(range(17))
        self.box3 = list(range(17))

    def definePlot(self, source):
        """Build the simulation chart and place it next to the map container.

        The main figure ``p1`` gets a dotted line plus circle markers for each
        compartment column of *source*: ``exp``, ``inf``, ``sin``, ``qua``,
        ``imm`` and ``dea``. All are plotted against ``x``, with a
        click-to-hide legend. A second figure ``p2`` (``sus`` column) is also
        built but, as before, is not part of the returned layout.

        Parameters
        ----------
        source : ColumnDataSource
            Must provide the columns ``x``, ``sus``, ``exp``, ``inf``,
            ``sin``, ``qua``, ``imm`` and ``dea``.

        Returns
        -------
        ``row(p1, row(map Div))`` -- the chart beside the SVG map holder.
        """

        def format_axes(fig):
            # Shared axis labels and integer-style tick formatting.
            fig.yaxis.axis_label = 'Number of people'
            fig.xaxis.axis_label = 'Simulation time (days)'
            fig.xaxis[0].formatter = PrintfTickFormatter(format="%9.0f")
            fig.yaxis[0].formatter = PrintfTickFormatter(format="%9.0f")
            fig.xaxis.major_label_text_font_size = "10pt"
            fig.yaxis.major_label_text_font_size = "10pt"

        def style_panel(fig):
            # Must run after the glyphs are added, so the legend created by
            # their ``legend=`` arguments exists.
            fig.legend.click_policy = 'hide'
            fig.background_fill_color = "black"
            fig.background_fill_alpha = 0.8
            fig.legend.location = "top_left"
            fig.legend.background_fill_color = "cyan"
            fig.legend.background_fill_alpha = 0.5
            fig.outline_line_width = 7
            fig.outline_line_alpha = 0.9
            fig.outline_line_color = "black"

        p1 = figure(**self.plot_options, title='Covid Simulation', toolbar_location='above')
        format_axes(p1)

        p2 = figure(**self.plot_options, title='Number of Susceptible people', toolbar_location='above')
        format_axes(p2)

        # NOTE(review): p2 is built and styled but not included in the
        # returned layout; kept as in the original -- confirm whether it
        # should be displayed or removed.
        p2.line(source=source, x='x', y='sus', color='cyan', line_width=1, line_dash='dashed', legend='Susceptible')
        p2.circle(source=source, x='x', y='sus', color='cyan', size=10, legend='Susceptible')

        # (column, colour, legend label) for each series on the main chart.
        # The original also built a ``Legend`` from these renderers but never
        # attached it to a figure; that dead object is no longer created.
        compartments = (
            ('exp', 'gold', 'Exposed'),
            ('inf', 'white', 'Infected'),
            ('sin', 'purple', 'Severe Infected'),
            ('qua', 'lime', 'Quarantined'),
            ('imm', 'deepskyblue', 'Immunized'),
            ('dea', 'red', 'Dead'),
        )
        for field, color, label in compartments:
            p1.line(source=source, x='x', y=field, color=color, line_width=1, line_dash='dotted', legend=label)
            p1.circle(source=source, x='x', y=field, color=color, size=10, legend=label)

        style_panel(p1)
        style_panel(p2)

        # Placeholder whose SVG element is filled by browser-side map code.
        kz_map_tag = Div(text="""<div id="svg_holder" style="float:left;"> <svg width="780" height="530" id="statesvg"></svg> <div id="tooltip"></div>   </div>""", width=960, height=600)
        kz_map_row = row(kz_map_tag)
        return row(p1, kz_map_row)

    #@gen.coroutine
    def update(self, change_view):
        """Refresh the chart/map and transition-table sources from simulation output.

        Reads the per-step snapshots in ``config.new_plot_all`` and rebuilds:

        * ``self.source`` -- time series for the chart. It also carries
          ``text``, the JSON-encoded per-region history for the map, and
          ``mdates``, the JSON ``[start date, current date]`` pair.
        * ``self.sourceT`` -- ``config.transition_matrix`` rows 0..16 as
          columns ``c0``..``c16`` (also kept as ``self.data1``).

        Parameters
        ----------
        change_view
            Unused; kept for call compatibility (e.g. ``SelectRegionHandler``
            passes ``True``).

        Notes
        -----
        Each snapshot is indexed as ``nodes[:, region, column]`` and its last
        row (``[-1]``) is used (presumably the latest time step -- confirm).
        Columns as read below:

        * 0 + 7 -> infected
        * 1 -> exposed
        * 2 -> severe infected
        * 3 -> quarantined
        * 4 -> immunized
        * 5 -> susceptible
        * 6 -> dead

        When ``config.region == 17`` the chart shows the sum over all
        regions; otherwise it shows region ``config.region``.
        """
        nodes_per_step = config.new_plot_all
        show_all_regions = config.region == 17

        def chart_value(nodes, column):
            # Value plotted on the chart for one column of one snapshot.
            if show_all_regions:
                return sum(nodes[:, :, column][-1])
            return nodes[:, config.region, column][-1]

        def region_snapshot(nodes, region):
            # One map entry per region; the key order is preserved in the JSON.
            def last(column):
                return nodes[:, region, column][-1]

            return {
                "tmp_state_inf": last(0) + last(7),
                "tmp_state_sin": last(2),
                "tmp_state_exp": last(1),
                "tmp_state_qua": last(3),
                "tmp_state_imm": last(4),
                "tmp_state_sus": last(5),
                "tmp_state_dea": last(6),
            }

        # Series start with the initial (day 0) state.
        newx = [0]
        state_inf = [0]
        state_sus = [config.param_init_susceptible[config.region]]
        state_exp = [config.param_init_exposed[config.region]]
        state_sin = [0]
        state_qua = [0]
        state_imm = [0]
        state_dea = [0]

        start = self.start_date
        cur_date = (start + timedelta(config.counter_func)).strftime("%d %b %Y")
        start_date = start.strftime("%d %b %Y")

        # Per-region history for the map, keyed by integer region index.
        region_states = {}

        for nodes in nodes_per_step:
            # for graph
            state_inf.append(chart_value(nodes, 0) + chart_value(nodes, 7))
            state_exp.append(chart_value(nodes, 1))
            state_sin.append(chart_value(nodes, 2))
            state_qua.append(chart_value(nodes, 3))
            state_imm.append(chart_value(nodes, 4))
            state_sus.append(chart_value(nodes, 5))
            state_dea.append(chart_value(nodes, 6))

            # for map -- accumulate every step for every region. (The
            # all-regions branch used to test ``str(region) in region_states``
            # against integer keys, so it overwrote each region's history
            # with only the latest step.)
            for region in range(17):
                snapshot = region_snapshot(nodes, region)
                if region not in region_states:
                    region_states[region] = {key: [] for key in snapshot}
                history = region_states[region]
                for key, value in snapshot.items():
                    history[key].append(value)

        if len(nodes_per_step) > 0:
            newx = config.param_sim_len[0] * (np.arange(config.counter_func + 1))

        # json.dumps turns the integer region keys into strings.
        str_data = json.dumps(region_states, ensure_ascii=False)
        str_mdates = json.dumps([start_date, cur_date], ensure_ascii=False)
        n_rows = len(state_imm)
        new_data = dict(x=newx, sus=state_sus, exp=state_exp, inf=state_inf,
                        sin=state_sin, qua=state_qua, imm=state_imm,
                        dea=state_dea, text=[str_data] * n_rows,
                        mdates=[str_mdates] * n_rows)

        matrix = config.transition_matrix
        self.data1 = {
            "c{}".format(row_idx): [matrix[row_idx, col_idx] for col_idx in range(17)]
            for row_idx in range(17)
        }

        self.source.data.update(new_data)
        self.sourceT.data.update(self.data1)
        self.data_tableT.update()

    def SelectRegionHandler(self, attr, old, new):
        """Bokeh ``on_change`` callback for the region dropdown.

        Sets ``config.region`` to the index of ``new`` in
        ``config.region_names`` (left unchanged when ``new`` is not found),
        refreshes the plot and reloads the per-region slider values.
        """
        matched_index = next(
            (idx for idx, name in enumerate(config.region_names) if new == name),
            None,
        )
        if matched_index is not None:
            config.region = matched_index
        self.update(True)
        self.slider_update_initial_val(self, old, new)

    def update_transition_matrix(self):
        """Rebuild ``config.transition_matrix`` from the transport checkboxes.

        ``self.box1``, ``self.box2`` and ``self.box3`` hold the active checkbox
        indices (region indices) for airways, railways and highways. A region
        left unticked for a transport mode gets both its row and its column
        zeroed in that mode's matrix.

        The three masked matrices are then combined as follows:

        * they are summed and scaled by ``0.5 * config.param_transition_scale[0]``;
        * entries below 0.01 are replaced by ``config.transition_matrix_init``
          times ``config.param_transition_leakage[0]``;
        * the result is truncated to ``int``.

        The checkbox table and the matrix are stored on ``config``, and the view
        is refreshed with ``self.update(False)``.
        """
        nodes_num = 17
        self.param_transition_box = [self.box1, self.box2, self.box3]
        tr_boxes = self.param_transition_box

        # param_transition_table[region, mode] == 1 when the region is ticked
        # for that mode (column 0 air, 1 rail, 2 road).
        param_transition_table = np.zeros((nodes_num, 3))
        for mode, active_regions in enumerate(tr_boxes):
            for node in active_regions:
                param_transition_table[int(node), mode] = 1

        # load transition matrix (work on copies so the config originals stay intact)
        transition_railway = config.transition_railway.copy()
        transition_airway = config.transition_airway.copy()
        transition_roadway = config.transition_roadway.copy()

        # Order must match the columns of param_transition_table.
        tr_table = [transition_airway, transition_railway, transition_roadway]

        for j, tr in enumerate(tr_table):
            for i in range(nodes_num):
                # Mask both the outgoing row and the incoming column of region i.
                # The column must be scaled from itself: copying the row into it
                # would overwrite tr[:, i] with tr[i, :] for asymmetric matrices.
                tr[i, :] = tr[i, :] * param_transition_table[i, j]
                tr[:, i] = tr[:, i] * param_transition_table[i, j]

        transition_matrix = 0.5 * (transition_railway + transition_airway + transition_roadway) * (config.param_transition_scale[0])

        # Near-zero entries fall back to a "leakage" fraction of the initial matrix.
        # Original note: base data is for 24 days, tran_dt = 1/2.
        leakage = config.param_transition_leakage[0]
        for i in range(nodes_num):
            for j in range(nodes_num):
                if transition_matrix[i, j] < 0.01:
                    transition_matrix[i, j] = config.transition_matrix_init[i, j] * leakage

        transition_matrix = transition_matrix.astype(int)

        config.param_transition_table = copy(param_transition_table)
        config.transition_matrix = copy(transition_matrix)
        self.update(False)

    def reset_click(self):
        """Discard the current run and restore the default parameters.

        Only acts when ``config.flag_sim == 0``. It then:

        * clears the stored run history and loaded state;
        * resets the sliders via ``slider_update_reset``;
        * refreshes the view.
        """
        if config.flag_sim != 0:
            return

        # run bookkeeping
        config.new_plot_all = []
        config.counter_func = 0
        config.run_iteration = False

        # previously simulated / loaded state
        config.last_state_list = []
        config.nodes_old = []
        config.new_plot = []
        config.is_loaded = False

        self.slider_update_reset(self, 0, 0)
        self.update(False)
        print('[INFO] Resetting the simulation parameters ..')

    def load_click(self):
        """Reload a run previously written to ``results/<config.param_save_file>/``.

        First resets the current run via ``reset_click``. It then reads:

        * ``Kazakhstan.csv``, only to count its rows;
        * one CSV per region: states in columns 2-9, parameters in columns
          10-24, transport checkbox flags in columns 25-27;
        * ``states_x.csv``, each row appended to ``config.last_state_list``.

        Afterwards it restores the parameters, checkboxes and sliders and sets
        the date picker to the date column of the first data row. When
        ``Kazakhstan.csv`` is missing, it prints a notice and stops.
        """
        self.reset_click()
        directory = 'results' + '/' + config.param_save_file
        fname = directory + '/' + 'Kazakhstan' + '.csv'

        if not os.path.isfile(fname):
            print('[INFO] No such folder to load the results.')
            return

        # Row count (header included) sizes the buffers below.
        with open(fname, "r", newline='') as f:
            row_n = sum(1 for _ in csv.reader(f, delimiter=","))

        # reset; index 0 stays zero because row 0 of each file is the header
        new_plot = np.zeros((row_n, 17, 8))
        config.box_time = np.zeros((17, 3, row_n))
        config.arr_for_save = np.zeros((row_n, 17, 15))

        # fill the new_plot
        for j in range(config.nodes_num):
            filename = directory + '/' + config.region_names[j] + '.csv'
            with open(filename, newline='') as csvfile:
                count_row = 0
                for record in csv.reader(csvfile, delimiter=','):
                    if count_row > 0:
                        # states
                        new_plot[count_row, j, :] = np.array([float(item) for item in record[2:10]])
                        # transport checkbox flags
                        config.box_time[j, :, count_row] = np.array([float(item) for item in record[25:28]])
                        # parameters
                        config.arr_for_save[count_row, j, :] = np.array([float(item) for item in record[10:25]])
                    count_row += 1
                    if count_row == 2:
                        # date column of the first data row
                        config.last_date = record[1]

        config.counter_load = count_row - 1
        config.new_plot = new_plot

        # restore parameters
        # NOTE(review): reset_click() sets counter_func to 0 when flag_sim == 0,
        # so `last` is then -1, i.e. the last loaded row -- confirm intended.
        last = config.counter_func - 1
        config.param_init_exposed = config.arr_for_save[last, :, 0]
        config.param_beta_exp = config.arr_for_save[last, :, 1]
        config.param_qr = config.arr_for_save[last, :, 2]
        config.param_sir = config.arr_for_save[last, :, 3]
        config.param_hosp_capacity = config.arr_for_save[last, :, 4]

        config.param_gamma_mor1 = config.arr_for_save[last, :, 5]
        config.param_gamma_mor2 = config.arr_for_save[last, :, 6]
        config.param_gamma_im = config.arr_for_save[last, :, 7]

        config.param_eps_exp = config.arr_for_save[last, :, 8]
        config.param_eps_qua = config.arr_for_save[last, :, 9]
        config.param_eps_sev = config.arr_for_save[last, :, 10]

        config.param_t_exp = config.arr_for_save[last, :, 11]
        config.param_t_inf = config.arr_for_save[last, :, 12]

        config.param_transition_leakage = config.arr_for_save[last, :, 13]
        config.param_transition_scale = config.arr_for_save[last, :, 14]
        config.param_transition_table = copy(config.box_time[:, :, last])

        # Compute all three lists before assigning any of them. Assigning .active
        # presumably fires the checkbox handlers. Those rebuild
        # config.param_transition_table, which would corrupt the later lists.
        l1 = [i for i, x in enumerate(config.param_transition_table[:, 0]) if x > 0]
        l2 = [i for i, x in enumerate(config.param_transition_table[:, 1]) if x > 0]
        l3 = [i for i, x in enumerate(config.param_transition_table[:, 2]) if x > 0]

        self.checkbox_group1.active = l1
        self.checkbox_group2.active = l2
        self.checkbox_group3.active = l3

        self.slider_update_initial_val(0, 0, 0)

        filename = directory + '/' + 'states_x' + '.csv'
        with open(filename, "r", newline='') as f:
            for record in csv.reader(f, delimiter=","):
                config.last_state_list.append(np.array(record))

        config.is_loaded = True
        # plot graph
        if config.flag_sim == 0:
            config.load_iteration = True

        self.datepicker.value = config.last_date

    def run_click(self):
        """Start a simulation run from the current settings.

        Only acts when ``config.flag_sim == 0``. It then:

        * rebuilds the transition matrix;
        * sets ``config.run_iteration``;
        * refreshes the view.
        """
        if config.flag_sim != 0:
            return
        self.update_transition_matrix()
        config.run_iteration = True
        self.update(False)

    def save_file_click(self):
        """Write the current run to ``results/<config.param_save_file>/`` as CSV.

        Writes three kinds of file:

        * one file per region: day index, date, integer states, the 15
          per-region parameters and the three transport checkbox flags;
        * ``Kazakhstan.csv``: the states summed over regions;
        * ``states_x.csv``: ``states_x`` of every node in ``config.nodes_old``.

        Only acts when ``config.flag_sim == 0``. Prints a notice when there is
        no data; the directory is created even in that case.
        """
        if config.flag_sim != 0:
            return

        # points*nodes*states
        info = config.header_file_csv
        info2 = config.header_file_csv2

        directory = 'results' + '/' + config.param_save_file
        if not os.path.exists(directory):
            os.makedirs(directory)

        if not config.new_plot_all:
            print('[INFO] No data to save.')
            return

        box_corr = config.box_time

        # NOTE(review): with QUOTE_NONE and escapechar=' ', delimiters inside a field
        # (e.g. a comma-separated header string) are written preceded by a space --
        # confirm this output format is intended.
        for j in range(17):
            filename = directory + '/' + self.region_names[j] + '.csv'
            with open(filename, 'w', newline='') as csvfile:
                data_writer = csv.writer(csvfile, delimiter=',', escapechar=' ', quoting=csv.QUOTE_NONE)
                data_writer.writerow([info])
                for day in range(1, config.counter_func + 1):
                    # last entry along the "points" axis for region j
                    node_states = config.new_plot_all[day - 1][-1, j, :].astype(int)
                    curr_date = self.start_date + timedelta(day - 1)
                    record = np.concatenate((node_states,
                                             config.arr_for_save[day, j, 0:15],
                                             box_corr[j, 0:3, day]))
                    data_writer.writerow([day, curr_date] + list(record))

        filename = directory + '/' + 'Kazakhstan' + '.csv'
        with open(filename, 'w', newline='') as csvfile:
            data_writer = csv.writer(csvfile, delimiter=',', escapechar=' ', quoting=csv.QUOTE_NONE)
            data_writer.writerow([info2])
            for day in range(1, config.counter_func + 1):
                # states summed over all regions
                totals = config.new_plot_all[day - 1][-1, :, :].astype(int).sum(axis=0)
                curr_date = self.start_date + timedelta(day - 1)
                data_writer.writerow([day, curr_date] + list(totals))

        # last state save
        filename = directory + '/' + 'states_x' + '.csv'
        with open(filename, 'w', newline='') as csvfile:
            data_writer = csv.writer(csvfile, delimiter=',', escapechar=' ', quoting=csv.QUOTE_NONE)
            for node in config.nodes_old:
                data_writer.writerow(list(node.states_x))

        print('[INFO] Saving results to .csv format ..')

    def slider_update_initial_val(self, attr, old, new):
        """Copy the stored parameter values into the sliders.

        Per-region sliders show element ``config.region`` of their parameter
        array. The global sliders (incubation/infection period, leakage,
        traffic ratio) show element 0. ``attr``/``old``/``new`` follow the
        Bokeh callback signature and are unused.
        """
        # (slider attribute on self, parameter array on config)
        region_sliders = (
            ('init_exposed', 'param_init_exposed'),
            ('sus_to_exp_slider', 'param_beta_exp'),
            ('param_qr_slider', 'param_qr'),
            ('param_sir', 'param_sir'),
            ('param_hosp_capacity', 'param_hosp_capacity'),
            ('param_gamma_mor1', 'param_gamma_mor1'),
            ('param_gamma_mor2', 'param_gamma_mor2'),
            ('param_gamma_im', 'param_gamma_im'),
            ('param_eps_exp', 'param_eps_exp'),
            ('param_eps_qua', 'param_eps_qua'),
            ('param_eps_sev', 'param_eps_sev'),
        )
        global_sliders = (
            ('param_t_exp', 'param_t_exp'),
            ('param_t_inf', 'param_t_inf'),
            ('param_tr_leakage', 'param_transition_leakage'),
            ('param_tr_scale', 'param_transition_scale'),
        )
        # config is re-read on every step: assigning .value may trigger handlers.
        for widget_name, param_name in region_sliders:
            getattr(self, widget_name).value = getattr(config, param_name)[config.region]
        for widget_name, param_name in global_sliders:
            getattr(self, widget_name).value = getattr(config, param_name)[0]

    def slider_update_reset(self, attr, old, new):
        """Restore every model parameter, slider and checkbox to its default.

        In order, it:

        * makes each per-region parameter a fresh length-17 array of its default;
        * reloads the sliders;
        * ticks every region for all three transport modes and rebuilds the
          transition matrix;
        * re-initialises ``config.box_time`` and ``config.arr_for_save``;
        * sets the start date to today.

        ``old``/``new`` are forwarded to ``slider_update_initial_val``.
        """
        region_count = 17

        # (config attribute, default applied to every region)
        uniform_defaults = (
            ('param_init_exposed', 0),
            ('param_beta_exp', 30.0),
            ('param_qr', 2.0),
            ('param_sir', 0.35),
            ('param_gamma_mor1', 7.0),
            ('param_gamma_mor2', 11.0),
            ('param_gamma_im', 90.0),
            ('param_eps_exp', 100.0),
            ('param_eps_qua', 20.0),
            ('param_eps_sev', 20.0),
            ('param_t_exp', 5),
            ('param_t_inf', 14),
            ('param_transition_leakage', 0.0),
            ('param_transition_scale', 1.0),
        )
        for param_name, default in uniform_defaults:
            # a separate array per parameter; the slider handlers edit them in place
            setattr(config, param_name, default * np.ones(region_count))
        # hospital capacity differs per region
        config.param_hosp_capacity = np.array((280, 2395, 895, 600, 650, 250, 725, 100, 885, 425,
                                               1670, 300, 465, 1420, 1505, 380, 300))

        self.slider_update_initial_val(self, old, new)

        for group_name in ('checkbox_group1', 'checkbox_group2', 'checkbox_group3'):
            getattr(self, group_name).active = list(range(0, 17))
        self.update_transition_matrix()

        config.box_time = copy(config.param_transition_table)
        # column order of arr_for_save
        saved_columns = ('param_init_exposed', 'param_beta_exp', 'param_qr', 'param_sir',
                         'param_hosp_capacity', 'param_gamma_mor1', 'param_gamma_mor2',
                         'param_gamma_im', 'param_eps_exp', 'param_eps_qua', 'param_eps_sev',
                         'param_t_exp', 'param_t_inf', 'param_transition_leakage',
                         'param_transition_scale')
        config.arr_for_save = np.dstack(tuple(getattr(config, name) for name in saved_columns))

        self.datepicker.value = datetime.today()
        self.start_date = datetime.today()

    def handler_beta_exp(self, attr, old, new):
        """Store the S->E slider value for the currently selected region."""
        region_idx = config.region
        config.param_beta_exp[region_idx] = new

    def handler_param_qr(self, attr, old, new):
        """Store the quarantine-rate slider value for the currently selected region."""
        region_idx = config.region
        config.param_qr[region_idx] = new

    def handler_param_sir(self, attr, old, new):
        """Store the infected-to-severe slider value for the currently selected region."""
        region_idx = config.region
        config.param_sir[region_idx] = new

    def handler_param_eps_exp(self, attr, old, new):
        """Store the Exposed transmission-rate slider value for the currently selected region."""
        region_idx = config.region
        config.param_eps_exp[region_idx] = new

    def handler_param_eps_qua(self, attr, old, new):
        """Store the Quarantined transmission-rate slider value for the currently selected region."""
        region_idx = config.region
        config.param_eps_qua[region_idx] = new

    def handler_param_eps_sev(self, attr, old, new):
        """Store the Severe-Infected transmission-rate slider value for the currently selected region."""
        region_idx = config.region
        config.param_eps_sev[region_idx] = new

    def handler_param_hosp_capacity(self, attr, old, new):
        """Store the hospital-capacity slider value for the currently selected region."""
        region_idx = config.region
        config.param_hosp_capacity[region_idx] = new

    def handler_param_gamma_mor1(self, attr, old, new):
        """Store the severe-to-dead slider value for the currently selected region."""
        region_idx = config.region
        config.param_gamma_mor1[region_idx] = new

    def handler_param_gamma_mor2(self, attr, old, new):
        """Store the severe-to-dead (capacity exceeded) slider value for the currently selected region."""
        region_idx = config.region
        config.param_gamma_mor2[region_idx] = new

    def handler_param_gamma_im(self, attr, old, new):
        """Store the infected-to-recovered slider value for the currently selected region."""
        region_idx = config.region
        config.param_gamma_im[region_idx] = new

    def handler_param_sim_len(self, attr, old, new):
        """Store the simulation-length slider value (days) as ``config.loop_num``."""
        days = new
        config.loop_num = days

    def handler_param_t_exp(self, attr, old, new):
        """Update the incubation period, but only before the run has advanced.

        Once ``config.counter_func`` reaches 1, the change is not stored;
        the sliders are reloaded from the stored values instead.
        """
        if config.counter_func >= 1:
            self.slider_update_initial_val(self, old, new)
            return
        config.param_t_exp[0] = new

    def handler_param_t_inf(self, attr, old, new):
        """Update the infection period, but only before the run has advanced.

        Once ``config.counter_func`` reaches 1, the change is not stored;
        the sliders are reloaded from the stored values instead.
        """
        if config.counter_func >= 1:
            self.slider_update_initial_val(self, old, new)
            return
        config.param_t_inf[0] = new

    def handler_init_exposed(self, attr, old, new):
        """Set the initial Exposed count for the selected region before the run starts.

        Before the run has advanced, it stores the value and refreshes the
        view. Afterwards it reloads the sliders from the stored values instead.
        """
        if config.counter_func >= 1:
            self.slider_update_initial_val(self, old, new)
            return
        config.param_init_exposed[config.region] = new
        self.update(False)

    def handler_param_tr_scale(self, attr, old, new):
        """Apply the traffic-ratio slider to every region and rebuild the transition matrix."""
        scale = np.ones(config.nodes_num) * new
        config.param_transition_scale = scale
        self.update_transition_matrix()

    def handler_param_tr_leakage(self, attr, old, new):
        """Apply the leakage-ratio slider to every region and rebuild the transition matrix."""
        leakage = np.ones(config.nodes_num) * new
        config.param_transition_leakage = leakage
        self.update_transition_matrix()

    def handler_checkbox_group1(self, new):
        """Record the ticked airway regions and rebuild the transition matrix."""
        self.box1 = new
        return self.update_transition_matrix() and None

    def handler_checkbox_group2(self, new):
        """Record the ticked railway regions and rebuild the transition matrix."""
        self.box2 = new
        return self.update_transition_matrix() and None

    def handler_checkbox_group3(self, new):
        """Record the ticked highway regions and rebuild the transition matrix."""
        self.box3 = new
        return self.update_transition_matrix() and None

    def handler_param_save_file(self, attr, old, new):
        """Remember the results folder name typed into the text input.

        NOTE(review): this text comes from a browser TextInput. It is later
        joined into a filesystem path ('results/' + name) without any
        sanitising.
        """
        folder_name = str(new)
        config.param_save_file = folder_name

    def get_date(self, attr, old, new):
        """Bokeh callback for the date picker: remember the chosen start date."""
        chosen = new
        self.start_date = chosen

    def layout(self):
        """Create every widget, connect its callback and add the page to ``self.doc``.

        Widgets that other methods read or update later are stored on ``self``:
        sliders, checkbox groups, the transition-matrix table and its source,
        the date picker and the title divs. The buttons, region selector and
        folder-name input stay local. Reads ``self.pAll`` and ``self.doc``, so
        both must be set before this is called.
        """
        nodes_num = 17

        # define text font, colors
        self.text1 = Div(text="""<h1 style="color:blue">COVID-19 Simulator for Kazakhstan</h1>""", width=500, height=50)
        self.text4 = Div(text="""<h1 style="color:blue"> </h1>""", width=900, height=50)

        self.text2 = Div(text="<b>Select parameters for each region</b>", style={'font-size': '150%', 'color': 'green'}, width=350)
        self.text3 = Div(text="<b>Select global parameters </b>", style={'font-size': '150%', 'color': 'green'})
        self.text5 = Div(text="<b>Change transition matrix</b>", style={'font-size': '150%', 'color': 'green'})

        # select region - dropdown menu
        regions = config.region_names

        initial_region = 'Almaty'
        region_selection = Select(value=initial_region, title=' ', options=regions, width=250, height=15)
        region_selection.on_change('value', self.SelectRegionHandler)

        # select parameters - sliders (per-region sliders start at the value for config.region)
        self.sus_to_exp_slider = Slider(start=0.0, end=50.0, step=0.5, value=config.param_beta_exp[config.region], title='Susceptible to Exposed transition constant (%)')
        self.sus_to_exp_slider.on_change('value', self.handler_beta_exp)

        self.param_qr_slider = Slider(start=0.0, end=25.0, step=0.25, value=config.param_qr[config.region], title='Daily Quarantine rate of the Exposed (%)')
        self.param_qr_slider.on_change('value', self.handler_param_qr)

        self.param_sir = Slider(start=0.0, end=5.0, step=0.05, value=config.param_sir[config.region], title='Daily Infected to Severe Infected transition rate (%)')
        self.param_sir.on_change('value', self.handler_param_sir)

        self.param_eps_exp = Slider(start=0, end=100, step=1.0, value=config.param_eps_exp[config.region], title='Disease transmission rate of Exposed compared to Infected (%)')
        self.param_eps_exp.on_change('value', self.handler_param_eps_exp)

        self.param_eps_qua = Slider(start=0, end=100, step=1.0, value=config.param_eps_qua[config.region], title='Disease transmission rate of Quarantined compared to Infected (%)')
        self.param_eps_qua.on_change('value', self.handler_param_eps_qua)

        self.param_eps_sev = Slider(start=0, end=100, step=1.0, value=config.param_eps_sev[config.region], title='Disease transmission rate of Severe Infected compared to Infected (%)')
        self.param_eps_sev.on_change('value', self.handler_param_eps_sev)

        self.param_hosp_capacity = Slider(start=0, end=10000, step=1, value=config.param_hosp_capacity[config.region], title='Hospital Capacity')
        self.param_hosp_capacity.on_change('value', self.handler_param_hosp_capacity)

        self.param_gamma_mor1 = Slider(start=0, end=100, step=1.0, value=config.param_gamma_mor1[config.region], title='Severe Infected to Dead transition probability (%)')
        self.param_gamma_mor1.on_change('value', self.handler_param_gamma_mor1)

        self.param_gamma_mor2 = Slider(start=0, end=100, step=1, value=config.param_gamma_mor2[config.region], title='Severe Infected to Dead transition probability (Hospital Cap. Exceeded) (%)')
        self.param_gamma_mor2.on_change('value', self.handler_param_gamma_mor2)

        self.param_gamma_im = Slider(start=0, end=100, step=1, value=config.param_gamma_im[config.region], title='Infected to Recovery Immunized transition probability (%)')
        self.param_gamma_im.on_change('value', self.handler_param_gamma_im)

        self.param_sim_len = Slider(start=1, end=100, step=1, value=config.loop_num, title='Length of simulation (Days)')
        self.param_sim_len.on_change('value', self.handler_param_sim_len)

        # global sliders use element 0 of their parameter arrays
        self.param_t_exp = Slider(start=1, end=20, step=1, value=config.param_t_exp[0], title='Incubation period (Days) ')
        self.param_t_exp.on_change('value', self.handler_param_t_exp)

        self.param_t_inf = Slider(start=1, end=20, step=1, value=config.param_t_inf[0], title=' Infection  period (Days) ')
        self.param_t_inf.on_change('value', self.handler_param_t_inf)

        self.init_exposed = Slider(start=0, end=100, step=1, value=config.param_init_exposed[config.region], title='Initial Exposed')
        self.init_exposed.on_change('value', self.handler_init_exposed)

        self.param_tr_scale = Slider(start=0.0, end=1, step=0.01, value=config.param_transition_scale[0], title='Traffic ratio')
        self.param_tr_scale.on_change('value', self.handler_param_tr_scale)

        self.param_tr_leakage = Slider(start=0.0, end=1, step=0.01, value=config.param_transition_leakage[0], title='Leakage ratio')
        self.param_tr_leakage.on_change('value', self.handler_param_tr_leakage)

        # horizontal spacers between slider columns
        dumdiv3 = Div(text='', width=200)
        dumdiv3ss = Div(text='', width=120)

        # Buttons
        reset_button = Button(label='Reset data', button_type='primary', background="red")
        save_button_result = Button(label='Save current plot to .csv in directory results/', button_type='primary')
        run_button = Button(label='Run the simulation', button_type='primary')
        load_button = Button(label='Load data from directory results/', button_type='primary')

        run_button.on_click(self.run_click)
        reset_button.on_click(self.reset_click)
        save_button_result.on_click(self.save_file_click)
        load_button.on_click(self.load_click)

        # input folder name
        text_save = TextInput(value="foldername", title="")
        text_save.on_change('value', self.handler_param_save_file)

        # transition matrix - checkbox (one group per transport mode, all regions ticked)
        div_cb1 = Div(text='Airways', width=150)
        div_cb2 = Div(text='Railways', width=150)
        div_cb3 = Div(text='Highways', width=150)

        self.checkbox_group1 = CheckboxGroup(labels=regions, active=list(range(0, 17)))
        self.checkbox_group2 = CheckboxGroup(labels=regions, active=list(range(0, 17)))
        self.checkbox_group3 = CheckboxGroup(labels=regions, active=list(range(0, 17)))

        self.checkbox_group1.on_click(self.handler_checkbox_group1)
        self.checkbox_group2.on_click(self.handler_checkbox_group2)
        self.checkbox_group3.on_click(self.handler_checkbox_group3)

        # transition matrix - table. Column "c<k>" lists row k of the matrix,
        # so table row i under column "c<k>" shows config.transition_matrix[k, i].
        self.data1 = dict(c00=regions)
        for k in range(nodes_num):
            self.data1['c' + str(k)] = [(config.transition_matrix[k, i]) for i in range(nodes_num)]

        # headers for c0..c16, in matrix-row order
        region_titles = [
            "Almaty", "Almaty Qalasy", "Aqmola", "Aqtobe", "Atyrau", "West Kazakhstan",
            "Jambyl", "Mangystau", "Nur-Sultan", "Pavlodar", "Qaragandy", "Qostanai",
            "Qyzylorda", "East Kazakhstan", "Shymkent", "North Kazakhstan", "Turkistan",
        ]
        columns = [TableColumn(field="c00", title=" ")]
        columns += [TableColumn(field='c' + str(k), title=title) for k, title in enumerate(region_titles)]

        self.sourceT = ColumnDataSource(self.data1)
        self.data_tableT = DataTable(source=self.sourceT, columns=columns, width=1750, height=500, sortable=False)

        # select start date - calendar
        self.datepicker = DatePicker(title="Starting date of simulation", min_date=datetime(2015, 11, 1), value=datetime.today())
        self.datepicker.on_change('value', self.get_date)

        # place the widgets on the layout

        sliders_1 = column(self.init_exposed, self.sus_to_exp_slider, self.param_qr_slider, self.param_sir)
        sliders_2 = column(self.param_hosp_capacity, self.param_gamma_mor1, self.param_gamma_mor2, self.param_gamma_im)
        sliders_0 = column(self.param_eps_exp, self.param_eps_qua, self.param_eps_sev)

        sliders = row(sliders_1, dumdiv3ss, sliders_2, dumdiv3, sliders_0)

        sliders_3 = row(self.param_t_exp, self.param_t_inf, self.param_sim_len, self.datepicker,)
        text2 = Div(text="""<h1 style='color:black'>   issai.nu.edu.kz/episim </h1>""", width=500, height=100)
        # Footer widgets are built but not currently attached (see the note on the last layout line).
        text_footer_1 = Div(text="""<h3 style='color:green'> Developed by ISSAI Researchers : Askat Kuzdeuov, Daulet Baimukashev, Bauyrzhan Ibragimov, Aknur Karabay, Almas Mirzakhmetov, Mukhamet Nurpeiissov and Huseyin Atakan Varol </h3>""", width=1500, height=10)
        text_footer_2 = Div(text="""<h3 style='color:red'> Disclaimer : This simulator is a research tool. The simulation results will show general trends based on entered parameters and initial conditions  </h3>""", width=1500, height=10)
        text_footer = column(text_footer_1, text_footer_2)
        text = column(self.text1, text2)

        # Redraw the browser-side map when the simulation is started.
        draw_map_js = CustomJS(code=""" uStates.draw("#statesvg", currRegionData, tooltipHtml); """)
        run_button.js_on_click(draw_map_js)

        layout_t = row(save_button_result, text_save, load_button)
        buttons = row(reset_button, run_button, layout_t)

        reg1 = row(self.text2, region_selection)

        buttons = column(buttons, reg1)

        params = column(sliders, self.text3, sliders_3, self.text5)

        sliders_4 = column(self.param_tr_scale, self.param_tr_leakage)
        check_table = row(column(div_cb1, self.checkbox_group1), column(div_cb2, self.checkbox_group2), column(div_cb3, self.checkbox_group3), sliders_4)
        check_trans = row(self.data_tableT)

        ###
        dummy_div11 = Div(text=""" """, height=5)
        layout = column(self.pAll, buttons)
        layout = column(layout, dummy_div11, params, check_table)

        # NOTE(review): self.text4 is placed here and again two lines below -- confirm intended.
        layout = column(layout, check_trans, self.text4)

        layout = column(layout)
        layout = column(layout, self.text4)     # text_footer

        self.doc.title = 'ISSAI Covid-19 Simulator'
        self.doc.add_root(layout)
Ejemplo n.º 18
0
    for c in plotted.copy():
        if c not in new:
            print('-', checkbox.labels[c])
            # del_dateline(checkbox.labels[c])
            plotted.remove(c)
    layout.children[1].children=[plot_t(), plot_c()]

def clearcountries():
    """Clear the country selection by emptying ``checkbox.active``."""
    checkbox.active = list()

def plottype_handler(new):
    """Redraw the first output plot after the plot-type buttons change.

    ``new`` (the button-group selection) is not used here.
    """
    fresh_plot = plot_t()
    layout.children[1].children[0] = fresh_plot

def c_period_handler(attr, old, new):
    """Slider callback: recompute the last-period data and redraw the second plot.

    It recomputes via ``calc_lastP`` for period ``int(new)``, then drops every
    cached ``<country>_lastP`` entry from ``plottimeline`` so that the
    following ``plot_c()`` call starts from fresh data.
    """
    calc_lastP(int(new))
    for country, _pop in population.items():
        stale_key = country + '_lastP'
        if stale_key in plottimeline:
            del plottimeline[stale_key]
    fresh_plot = plot_c()
    layout.children[1].children[1] = fresh_plot

# Connect widget events to their handlers (the widgets are presumably created
# earlier in this script -- not visible here).
checkbox.on_click(update_country)
btn_clear.on_click(clearcountries)
btng_main.on_click(plottype_handler)
slide_c_period.on_change('value', c_period_handler)

# Set up layouts and add to document
# The handlers above look up the global `layout` only when an event fires, so it
# just has to exist by then. layout.children[1] is `outputs`; the handlers
# replace its children[0] (plot_t) and children[1] (plot_c) slots.
inputs = column(btng_main, slide_c_period, btn_clear, checkbox)
outputs = column(plot_t(),plot_c())
layout = row(inputs,outputs)
curdoc().add_root(layout)
curdoc().title = "Plot"
Ejemplo n.º 19
0
class Plots:
    """
    Holds all the parameters that live inside this specific plot as well as all of the graphical elements for it.

    The GUI is made of a title, two tabs (linear and logarithmic y-axis) that show one
    line per parameter, a checkbox group to toggle individual lines, and
    "select all" / "deselect all" buttons.

    :param name: Name of the plot.
    :param plot_params: List of PlotParameter containing the parameters of this plot.
    :param param_names: List of the names of the parameters.
    """

    def __init__(self, name: str,
                 plot_params: List[PlotParameter],
                 param_names: List[str]):
        # load values
        self.name = name
        self.plot_params = plot_params
        self.param_names = param_names

        self.title = Paragraph(text=self.name, width=1000, height=200,
                               style={
                                   'font-size': '40pt',
                                   'font-weight': 'bold'
                               })

        # set up the tools and the figure itself
        self.tools = 'pan,wheel_zoom,box_zoom,reset,save'
        self.hover_tool = HoverTool(
            tooltips=[
                ('value', '$y'),
                ('time', '$x{%H:%M:%S}')
            ],
            formatters={
                '$x': 'datetime'
            },

            mode='mouse'
        )
        self.ticker_datetime = DatetimeTickFormatter(minsec=['%H:%M:%S %m/%d'])
        # infinite iterator cycling through the Category10 palette
        self.colors = self.colors_gen()

        # setting up the linear figure
        self.fig_linear = figure(width=1000, height=1000,
                                 tools=self.tools, x_axis_type='datetime')
        self.fig_linear.xaxis[0].formatter = self.ticker_datetime

        # attach the hover tool
        # NOTE(review): the same HoverTool instance is added to both figures; confirm the
        # Bokeh version in use supports sharing one tool between plots.
        self.fig_linear.add_tools(self.hover_tool)
        # automatically updates the range of the y-axis to center only visible lines
        self.fig_linear.y_range = DataRange1d(only_visible=True)

        # setting up the log figure
        self.fig_log = figure(width=1000, height=1000,
                              tools=self.tools,
                              x_axis_type='datetime', y_axis_type='log')
        self.fig_log.xaxis[0].formatter = self.ticker_datetime

        # attach the hover tool
        self.fig_log.add_tools(self.hover_tool)
        # automatically updates the range of the y-axis to center only visible lines
        self.fig_log.y_range = DataRange1d(only_visible=True)

        # creates the lines: one per parameter in each figure, same color in both
        self.lines_linear = []
        self.lines_log = []
        for params, c in zip(self.plot_params, self.colors):
            self.lines_linear.append(params.create_line(fig=self.fig_linear, color=c))
            self.lines_log.append(params.create_line(fig=self.fig_log, color=c))

        # Create the checkbox (all lines initially checked) and the buttons and their respective triggers
        self.checkbox = CheckboxGroup(labels=self.param_names, active=list(range(len(self.param_names))))
        self.all_button = Button(label='select all')
        self.none_button = Button(label='deselect all')

        self.checkbox.on_click(self.update_lines)
        self.all_button.on_click(self.all_selected)
        self.none_button.on_click(self.none_selected)

        # create the panels to switch from linear to log
        panels = [Panel(child=self.fig_linear, title='linear'), Panel(child=self.fig_log, title='log')]

        # creates the layout with all of the elements of this plot GUI
        self.layout = column(self.title,
                             row(Tabs(tabs=panels)),
                             self.checkbox,
                             row(self.all_button, self.none_button))

    def colors_gen(self):
        """
        Infinite iterator that cycles through the ten Category10 colors.
        """
        yield from itertools.cycle(Category10[10])

    def update_lines(self, active: List[int]):
        """
        Updates which lines are visible. Called whenever a checkbox or one of the buttons is clicked.

        :param active: Indices of the lines that should be visible; every other line is hidden
                       in both the linear and the log figure.
        """
        # set membership is O(1) per line, instead of scanning the list for every line
        visible_indices = set(active)
        for i, (line_linear, line_log) in enumerate(zip(self.lines_linear, self.lines_log)):
            is_visible = i in visible_indices
            line_linear.visible = is_visible
            line_log.visible = is_visible

    def all_selected(self):
        """
        Sets all the lines to be visible. Gets called when the select all button is clicked
        """
        self.checkbox.active = list(range(len(self.lines_linear)))
        self.update_lines(self.checkbox.active)

    def none_selected(self):
        """
        Sets all the lines to be invisible. Gets called when the deselect all button is clicked
        """
        self.checkbox.active = []
        self.update_lines(self.checkbox.active)

    def update_parameters(self, load_data: pd.DataFrame):
        """
        Pushes freshly loaded data into every parameter of this plot.

        :param load_data: Long-format frame with ``name``, ``value`` and ``time`` columns;
                          the rows whose ``name`` equals a parameter's name are handed to
                          that parameter's ``update(data, time)``.
        """

        for params in self.plot_params:
            reduced_data_frame = load_data[load_data['name'] == params.name]
            data = reduced_data_frame['value'].tolist()
            time = reduced_data_frame['time'].tolist()

            params.update(data, time)
Ejemplo n.º 20
0
    def __init__(self, image_views, disp_min=0, disp_max=1000, colormap="plasma"):
        """Initialize a colormapper.

        Creates a linear and a logarithmic color mapper that are kept in sync (palette,
        display range, high color, NaN/mask color), a horizontal color bar, and the
        control widgets: colormap select, auto-range toggle, linear/log radio group,
        min/max display spinners and high/mask color pickers. The color bar and every
        widget are stored as attributes on ``self``.

        Args:
            image_views (ImageView): Associated streamvis image view instances
                (iterated; each view's image glyph is given the active color mapper).
            disp_min (int, optional): Initial minimal display value. Defaults to 0.
            disp_max (int, optional): Initial maximal display value. Defaults to 1000.
            colormap (str, optional): Initial colormap. Defaults to 'plasma'.
        """
        lin_colormapper = LinearColorMapper(
            palette=cmap_dict[colormap], low=disp_min, high=disp_max
        )

        log_colormapper = LogColorMapper(palette=cmap_dict[colormap], low=disp_min, high=disp_max)

        # Images start on the linear mapper; the scale radio group below swaps mappers.
        for image_view in image_views:
            image_view.image_glyph.color_mapper = lin_colormapper

        color_bar = ColorBar(
            color_mapper=lin_colormapper,
            location=(0, -5),
            orientation="horizontal",
            height=15,
            width=100,
            padding=5,
        )
        self.color_bar = color_bar

        # ---- selector
        def select_callback(_attr, _old, new):
            # Unknown names are ignored. Otherwise the palette is applied to both mappers
            # and the high color is reset to the palette's last entry. `high_color` is the
            # ColorPicker created further below; the closure resolves it at call time.
            if new in cmap_dict:
                lin_colormapper.palette = cmap_dict[new]
                log_colormapper.palette = cmap_dict[new]
                high_color.color = cmap_dict[new][-1]

        select = Select(
            title="Colormap:", value=colormap, options=list(cmap_dict.keys()), default_size=100
        )
        select.on_change("value", select_callback)
        self.select = select

        # ---- auto toggle button
        def auto_toggle_callback(state):
            # `state` is the list of checked indices: while auto range is on, the manual
            # min/max spinners (defined further below) are disabled.
            if state:
                display_min_spinner.disabled = True
                display_max_spinner.disabled = True
            else:
                display_min_spinner.disabled = False
                display_max_spinner.disabled = False

        auto_toggle = CheckboxGroup(labels=["Auto Colormap Range"], default_size=145)
        auto_toggle.on_click(auto_toggle_callback)
        self.auto_toggle = auto_toggle

        # ---- scale radiobutton group
        def scale_radiobuttongroup_callback(selection):
            if selection == 0:  # Linear
                for image_view in image_views:
                    image_view.image_glyph.color_mapper = lin_colormapper
                color_bar.color_mapper = lin_colormapper
                color_bar.ticker = BasicTicker()

            else:  # Logarithmic
                # Log scale needs a strictly positive lower bound; otherwise the
                # selection is reverted to Linear (index 0).
                # NOTE(review): self.disp_min is not assigned in this method, so it is
                # presumably a property defined elsewhere in the class.
                if self.disp_min > 0:
                    for image_view in image_views:
                        image_view.image_glyph.color_mapper = log_colormapper
                    color_bar.color_mapper = log_colormapper
                    color_bar.ticker = LogTicker()
                else:
                    scale_radiobuttongroup.active = 0

        scale_radiobuttongroup = RadioGroup(
            labels=["Linear", "Logarithmic"], active=0, default_size=145
        )
        scale_radiobuttongroup.on_click(scale_radiobuttongroup_callback)
        self.scale_radiobuttongroup = scale_radiobuttongroup

        # ---- display max value
        def display_max_spinner_callback(_attr, _old_value, new_value):
            # Keep min strictly below max (by at least STEP); a non-positive max
            # forces the linear scale.
            self.display_min_spinner.high = new_value - STEP
            if new_value <= 0:
                scale_radiobuttongroup.active = 0

            lin_colormapper.high = new_value
            log_colormapper.high = new_value

        display_max_spinner = Spinner(
            title="Max Display Value:",
            low=disp_min + STEP,
            value=disp_max,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        display_max_spinner.on_change("value", display_max_spinner_callback)
        self.display_max_spinner = display_max_spinner

        # ---- display min value
        def display_min_spinner_callback(_attr, _old_value, new_value):
            # Keep max strictly above min (by at least STEP); a non-positive min
            # forces the linear scale.
            self.display_max_spinner.low = new_value + STEP
            if new_value <= 0:
                scale_radiobuttongroup.active = 0

            lin_colormapper.low = new_value
            log_colormapper.low = new_value

        display_min_spinner = Spinner(
            title="Min Display Value:",
            high=disp_max - STEP,
            value=disp_min,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        display_min_spinner.on_change("value", display_min_spinner_callback)
        self.display_min_spinner = display_min_spinner

        # ---- high color
        def high_color_callback(_attr, _old_value, new_value):
            # Color used for values above the display max, on both mappers.
            lin_colormapper.high_color = new_value
            log_colormapper.high_color = new_value

        high_color = ColorPicker(
            title="High Color:", color=cmap_dict[colormap][-1], default_size=90
        )
        high_color.on_change("color", high_color_callback)
        self.high_color = high_color

        # ---- mask color
        def mask_color_callback(_attr, _old_value, new_value):
            # Color used for NaN pixels, on both mappers.
            lin_colormapper.nan_color = new_value
            log_colormapper.nan_color = new_value

        mask_color = ColorPicker(title="Mask Color:", color="gray", default_size=90)
        mask_color.on_change("color", mask_color_callback)
        self.mask_color = mask_color
Ejemplo n.º 21
0
    <div>
        <div>
            <span style="font-size: 12px;">@CapitalName_x, @{Country/Region}</span>
        </div>
    </div>
"""
# Tooltips are attached only to the capitals glyph renderer.
plot.add_tools(
    HoverTool(renderers=[capitals_glyph], tooltips=capitals_tooltips))

# Single checkbox toggling the capitals layer; it starts checked (active=[0]).
capitals_checkbox = CheckboxGroup(labels=['Show capitals'], active=[0])


def show_capitals_callback(active):
    """
    Handler for the 'Show capitals' checkbox.

    ``active`` is the list of checked indices: the capitals glyph is made visible
    when the list is non-empty and hidden when it is empty.
    """
    capitals_glyph.visible = bool(active)


capitals_checkbox.on_click(show_capitals_callback)

plot.background_fill_color = '#f0f0f0'

# Checkbox stacked above the plot, added as the document root.
layout = column(capitals_checkbox, plot)
curdoc().add_root(layout)
Ejemplo n.º 22
0
          line_width=3,
          color='red')
# NOTE(review): `legend=` is deprecated since Bokeh 1.4 in favour of `legend_label=` /
# `legend_field=`; confirm whether 'Material' is meant as a fixed label or a column name.
plot.line('Time (s)',
          'Contraction (%)',
          source=source,
          legend='Material',
          line_width=3)
# plot.circle('Time (s)', 'Temperature (°C)', source = source, legend = 'Material', size = 10, color = dict(field = 'Material', transform=color_mapper) )

# CREATE MATERIAL, LOAD AND COIL-INDEX (CI) CHECKBOX GROUPS (first option of each checked)
mats_menu = CheckboxGroup(labels=mat_list, active=[0])
loads_menu = CheckboxGroup(labels=load_list, active=[0])
cis_menu = CheckboxGroup(labels=ci_list, active=[0])

# ADD UPDATE EVENT LISTENING (all three groups share the same `update` handler)
mats_menu.on_click(update)
loads_menu.on_click(update)
cis_menu.on_click(update)

# ADD A TITLE FOR EACH OF THE THREE CHECKBOX GROUPS
mat_title = Div(text="""<h3>Material:</h3>""", width=200)
loads_title = Div(text="""<h3>Load (g):</h3>""", width=200)
cis_title = Div(text="""<h3>Coil Index:</h3>""", width=200)

# ADD PLOT TITLE
plot.title.text = 'Temperature & Contraction over Time'

# ADJUST LEGEND POSITION
plot.legend.location = 'top_left'

# CREATE THE LAYOUT
Ejemplo n.º 23
0
class Plot:
    """Main plotting panel: dataset/tool/axis selectors, axis-transform checkboxes
    and the scatter figure showing the selected dataset columns."""

    def __init__(self, parent, dataset, parameters):
        """
        Build the controls, register their callbacks, create the figure and plot
        the initial selection.

        :param parent: Owning application object; its ``change_tool`` is called by
            ``tool_callback`` and, if it has a ``layout``, the figure inside it is
            replaced whenever the figure is rebuilt.
        :param dataset: Column-indexable data (``dataset[column]``) that contains a
            ``"ticid"`` column.
        :param parameters: Mapping from displayed parameter name to dataset column name.
        """

        self.parent = parent
        self.dataset = dataset

        # Set up the controls
        self.tools = Selector(
            name="Beverages",
            descr="Choose a plotting tool",
            kind="tools",
            css_classes=["tools"],
            entries={"Deli-LATTE": tools.DeliLATTE},
            default="None",
            none_allowed=True,
        )
        self.data = Selector(
            name="Main Dishes",
            descr="Choose a dataset",
            kind="datasets",
            css_classes=["data"],
            entries={
                "Test data": "test",
                # "TOI Catalog": "toi",
                # "Confirmed Planets": "confirmed",
            },
            default="Test data",
        )
        self.xaxis = Selector(
            name="Build-Your-Own",
            descr="Choose the parameters to plot",
            kind="parameters",
            css_classes=["build-your-own"],
            entries=parameters,
            default="ra",
            title="X Axis",
        )
        self.yaxis = Selector(
            kind="parameters",
            css_classes=["build-your-own"],
            entries=parameters,
            default="dist",
            title="Y Axis",
        )
        self.size = Selector(
            name="Sides",
            descr="Choose additional parameters to plot",
            kind="parameters",
            css_classes=["sides"],
            entries=parameters,
            default="None",
            title="Marker Size",
            none_allowed=True,
        )
        self.color = Selector(
            kind="parameters",
            css_classes=["sides"],
            entries=parameters,
            default="None",
            title="Marker Color",
            none_allowed=True,
        )
        # Index i of checkbox_group.active maps to checkbox_labels[i] (see checkbox_callback)
        self.checkbox_labels = [
            "Flip x-axis",
            "Flip y-axis ",
            "Log scale x-axis",
            "Log scale y-axis",
        ]
        self.specials = Selector(
            name="Specials",
            descr="Choose a special",
            kind="specials",
            css_classes=["specials"],
            entries={},
            default="None",
            none_allowed=True,
        )

        self.checkbox_group = CheckboxGroup(labels=self.checkbox_labels,
                                            active=[])

        self.source = ColumnDataSource(
            data=dict(x=[], y=[], size=[], color=[]))

        # Register the callbacks
        for control in [self.xaxis, self.yaxis, self.size, self.color]:
            control.widget.on_change("value", self.param_callback)
        self.tools.widget.on_change("value", self.tool_callback)
        self.data.widget.on_change("value", self.data_callback)
        self.checkbox_group.on_click(self.checkbox_callback)

        # Setup the plot
        self.setup_plot()

        # Load and display the data
        self.param_callback(None, None, None)

    def setup_plot(
        self,
        x_axis_type="linear",
        y_axis_type="linear",
        x_flip=False,
        y_flip=False,
    ):
        """Create a fresh scatter figure bound to ``self.source`` and install it.

        A new figure is needed because axis types cannot be changed in place.

        :param x_axis_type: ``"linear"`` or ``"log"`` for the x axis.
        :param y_axis_type: ``"linear"`` or ``"log"`` for the y axis.
        :param x_flip: Reverse the x range when True.
        :param y_flip: Reverse the y range when True.
        """
        # Set up the plot
        self.plot = figure(
            plot_height=620,
            min_width=600,
            title="",
            x_axis_type=x_axis_type,
            y_axis_type=y_axis_type,
            tools="",
            sizing_mode="stretch_both",
        )

        # Enable Bokeh tools
        self.plot.add_tools(PanTool(), TapTool(), ResetTool())

        # Axes orientation and labels
        self.plot.x_range.flipped = x_flip
        self.plot.y_range.flipped = y_flip
        self.plot.xaxis.axis_label = self.xaxis.value
        self.plot.yaxis.axis_label = self.yaxis.value

        # Plot the data; color column is expected to be normalized to [0, 1]
        self.plot.circle(
            x="x",
            y="y",
            source=self.source,
            size="size",
            color=linear_cmap(field_name="color",
                              palette=Viridis256,
                              low=0,
                              high=1),
            line_color=None,
        )

        # -- HACKZ --

        # Update the plot element in the HTML layout
        # NOTE(review): layout() below returns row(column(header, tabs), Spacer, plot),
        # where the plot is children[-1] of the row itself; confirm parent.layout wraps
        # it such that children[0].children[-1] really is the figure.
        if hasattr(self.parent, "layout"):
            self.parent.layout.children[0].children[-1] = self.plot

        # Make the cursor a grabber when panning
        code_pan_start = """
            Bokeh.grabbing = true
            var elm = document.getElementsByClassName('bk-canvas-events')[0]
            elm.style.cursor = 'grabbing'
        """
        code_pan_end = """
            if(Bokeh.grabbing) {
                Bokeh.grabbing = false
                var elm = document.getElementsByClassName('bk-canvas-events')[0]
                elm.style.cursor = 'grab'
            }
        """
        self.plot.js_on_event("panstart", CustomJS(code=code_pan_start))
        self.plot.js_on_event("panend", CustomJS(code=code_pan_end))

        # Add a hover tool w/ a pointer cursor.
        # `typeof x === 'undefined'` is the correct undefined test; the previous
        # `x == 'undefined'` compared against the *string* 'undefined'.
        code_hover = """
        if((typeof Bokeh.grabbing === 'undefined') || !Bokeh.grabbing) {
            var elm = document.getElementsByClassName('bk-canvas-events')[0]
            if (cb_data.index.indices.length > 0) {
                elm.style.cursor = 'pointer'
                Bokeh.pointing = true
            } else {
                if((typeof Bokeh.pointing === 'undefined') || !Bokeh.pointing)
                    elm.style.cursor = 'grab'
                else
                    Bokeh.pointing = false
            }
        }
        """
        self.plot.add_tools(
            HoverTool(
                callback=CustomJS(code=code_hover),
                tooltips=[("TIC ID", "@ticid")],
            ))

    def tool_callback(self, attr, old, new):
        """Switch the parent's active tool to the selected one (BaseTool for "None")."""
        if self.tools.value != "None":
            self.parent.change_tool(self.tools.entries[self.tools.value])
        else:
            self.parent.change_tool(tools.BaseTool)

    def data_callback(self, attr, old, new):
        # TODO: Change datasets!
        pass

    @staticmethod
    def _normalize(values):
        """Linearly rescale ``values`` onto [0, 1].

        NaN entries are ignored when computing the range and remain NaN. Returns None
        when no meaningful rescaling exists (empty input, or a zero / non-finite range
        such as a constant column), so callers can fall back to a default instead of
        producing NaN from a 0/0 division.
        """
        if len(values) == 0:
            return None
        lo = np.nanmin(values)
        span = np.nanmax(values) - lo
        if not np.isfinite(span) or span == 0:
            return None
        return (values - lo) / span

    def param_callback(self, attr, old, new):
        """
        Triggered when the user changes what we're plotting on the main plot.

        Rebuilds ``self.source.data`` from the selected x / y columns, plus marker
        size (0-25, or a constant 5 when unset or not scalable) and marker color
        (normalized to [0, 1], or 0 when unset or not scalable).
        """
        # Update the axis labels
        x_name = self.xaxis.entries[self.xaxis.value]
        y_name = self.yaxis.entries[self.yaxis.value]
        self.plot.xaxis.axis_label = self.xaxis.value
        self.plot.yaxis.axis_label = self.yaxis.value

        # Update the "sides": defaults first, overridden when a usable column is chosen
        size = np.ones_like(self.dataset["ticid"]) * 5
        if self.size.value != "None":
            scaled = self._normalize(self.dataset[self.size.entries[self.size.value]])
            if scaled is not None:
                size = 25 * scaled

        color = np.zeros_like(self.dataset["ticid"])
        if self.color.value != "None":
            scaled = self._normalize(self.dataset[self.color.entries[self.color.value]])
            if scaled is not None:
                color = scaled

        # Update the data source
        self.source.data = dict(
            x=self.dataset[x_name],
            y=self.dataset[y_name],
            size=size,
            ticid=self.dataset["ticid"],
            color=color,
        )

    def checkbox_callback(self, new):
        """
        Triggered when the user interacts with check boxes in appearance panel.

        Checkbox indices: 0 flips x, 1 flips y, 2 makes x logarithmic, 3 makes y
        logarithmic. The figure is rebuilt with the resulting settings.
        """
        active = set(self.checkbox_group.active)

        # NOTE: axis labels were reported to disappear on selection of checkboxes.
        self.setup_plot(
            x_axis_type="log" if 2 in active else "linear",
            y_axis_type="log" if 3 in active else "linear",
            x_flip=0 in active,
            y_flip=1 in active,
        )

    def layout(self):
        """Return the full UI: logo + control tabs on the left, the figure on the right."""
        panels = [None, None, None]

        # Main panel: data
        panels[0] = Panel(
            child=row(
                row(
                    column(
                        self.data.layout(),
                        Spacer(height=10),
                        self.tools.layout(),
                        width=160,
                    ),
                    Spacer(width=10),
                    column(
                        self.xaxis.layout([self.yaxis.widget]),
                        Spacer(height=10),
                        self.size.layout([self.color.widget]),
                    ),
                    css_classes=["panel-inner"],
                ),
                css_classes=["panel-outer"],
            ),
            title="Main Menu",
        )

        # Secondary panel: prix fixe
        panels[1] = Panel(
            child=row(
                row(self.specials.layout(), css_classes=["panel-inner"]),
                css_classes=["panel-outer"],
            ),
            title="Prix Fixe",
        )

        # Tertiary panel: appearance
        checkbox_panel = row(
            row(
                column(
                    Div(
                        text=
                        """<h2>Toppings</h2><h3>Choose axes transforms</h3>""",
                        css_classes=["controls"],
                    ),
                    row(
                        self.checkbox_group,
                        width=160,
                        css_classes=["controls"],
                    ),
                    css_classes=["axes-checkboxes"],
                ),
                css_classes=["panel-inner"],
            ),
            css_classes=["panel-outer"],
        )
        panels[2] = Panel(child=checkbox_panel, title="Garnishes")

        # All tabs
        tabs = Tabs(tabs=panels, css_classes=["tabs"])

        # Logo
        header = Div(
            text=f"""<img src="{LOGO_URL}"></img>""",
            css_classes=["header-image"],
            width=320,
            height=100,
        )

        return row(column(header, tabs), Spacer(width=30), self.plot)
Ejemplo n.º 24
0
    def __init__(self,
                 image_views,
                 sv_metadata,
                 sv_streamctrl,
                 positions=POSITIONS):
        """Initialize a resolution rings overlay.

        Adds ring (ellipse), ring-label (text) and beam-center (cross) glyphs to every
        image view. It also exposes a "Resolution Rings" checkbox that replaces the last
        tool of each view's plot: either a plain intensity hover tool, or one that also
        reports the resolution computed by the ``js_resolution`` formatter.

        Args:
            image_views (ImageView): Associated streamvis image view instances.
            sv_metadata (MetadataHandler): A metadata handler to report metadata issues.
            sv_streamctrl (StreamControl): A StreamControl instance of an application.
            positions (list, optional): Scattering radii in Angstroms. Defaults to
                [1.4, 1.5, 1.6, 1.8, 2, 2.2, 2.6, 3, 5, 10].
        """
        self._sv_metadata = sv_metadata
        self._sv_streamctrl = sv_streamctrl
        self.positions = np.array(positions)

        # ---- geometry parameters consumed by the resolution hover formatter
        self._formatter_source = ColumnDataSource(data={
            "detector_distance": [np.nan],
            "beam_energy": [np.nan],
            "beam_center_x": [np.nan],
            "beam_center_y": [np.nan],
        })

        resolution_formatter = CustomJSHover(
            args={"params": self._formatter_source}, code=js_resolution)

        intensity_tooltip = ("intensity", "@image")
        hovertool_off = HoverTool(tooltips=[intensity_tooltip],
                                  names=["image_glyph"])
        hovertool_on = HoverTool(
            tooltips=[intensity_tooltip, ("resolution", "@x{resolution} Å")],
            formatters={"@x": resolution_formatter},
            names=["image_glyph"],
        )

        # ---- ring geometry source and the three overlay glyphs
        self._source = ColumnDataSource(
            dict(x=[], y=[], w=[], h=[], text_x=[], text_y=[], text=[]))
        ring_glyph = Ellipse(x="x", y="y", width="w", height="h",
                             fill_alpha=0, line_color="white")
        label_glyph = Text(
            x="text_x", y="text_y", text="text",
            text_align="center", text_baseline="middle", text_color="white",
        )
        beam_center_glyph = Cross(x="beam_center_x", y="beam_center_y",
                                  size=15, line_color="red")

        overlay = (
            (self._source, ring_glyph),
            (self._source, label_glyph),
            (self._formatter_source, beam_center_glyph),
        )
        for image_view in image_views:
            for glyph_source, glyph in overlay:
                image_view.plot.add_glyph(glyph_source, glyph)

        # ---- toggle button
        def toggle_callback(state):
            # `state` lists the checked indices; any check enables the resolution hover.
            chosen = hovertool_on if state else hovertool_off
            for view in image_views:
                view.plot.tools[-1] = chosen

        toggle = CheckboxGroup(labels=["Resolution Rings"], default_size=145)
        toggle.on_click(toggle_callback)
        self.toggle = toggle
Ejemplo n.º 25
0
def checkbox_group_update(attrname):
    """Refresh the plotted data after the checkbox selection changes.

    The argument is whatever the checkbox's ``on_click`` passes (the active
    indices, despite the parameter name). It is not used here; the current year
    range is read from ``range_slider`` and forwarded to ``update_data``.
    """
    year_bounds = range_slider.value
    update_data(year_bounds[0], year_bounds[1])


# Categories shown as checkboxes; index i in checkbox_group.active refers to LABELS[i].
LABELS = [
    "Residential", "Vacant Land", "Commercial", "Entertainment",
    "Community Services", "Industrial", "Public Services", "Parks"
]
default = [0]  # only the first category ("Residential") starts checked
checkbox_group = CheckboxGroup(labels=LABELS,
                               active=default,
                               margin=(5, 5, 5, 5),
                               inline=True)
checkbox_group.on_click(checkbox_group_update)

# create layout
# NOTE(review): this rebinds the module name `layout`, shadowing the layout(...)
# callable it invokes; any later layout(...) call in this module would fail.
layout = layout(
    [
        [div],
        [range_slider],
        [checkbox_group],
        [p],
    ],
    sizing_mode="stretch_width",
    margin=(10, 50, 10, 50),
)

curdoc().add_root(layout)
curdoc().title = "Visual"
Ejemplo n.º 26
0
                most_recent_seq_num = m.seq_number
            doc.add_next_tick_callback(partial(update, m=m))


def send_thread_fn():
    """Heartbeat sender loop; never returns.

    Roughly once per second, if the module-level ``connected`` flag is set, stamps
    the latest ``most_recent_seq_num`` into one reusable ``HeartbeatMsg`` as its
    ``ack_num`` and sends it over UDP to (SEND_UDP_IP, SEND_UDP_PORT).
    """
    # Module globals are only read here, so no `global` declaration is needed.
    heartbeat = HeartbeatMsg()
    while True:
        if connected:
            heartbeat.ack_num = most_recent_seq_num
            send_sock.sendto(bytes(heartbeat), (SEND_UDP_IP, SEND_UDP_PORT))
        time.sleep(1.0)


def connected_checkbox_on_click(new):
    """Set the global ``connected`` flag from the "Connected" checkbox.

    Args:
        new: List of active checkbox indices; the single box is checked when
            index 0 is present.
    """
    global connected
    # Membership test instead of `new[0]`: unchecking the box yields an empty
    # list, which made the indexing version raise IndexError.
    connected = 0 in new


# Checkbox driving the global `connected` flag via connected_checkbox_on_click.
# NOTE(review): there is only one label, so the only valid index is 0; active=[1]
# points at a non-existent option and the box likely renders unchecked — confirm intent.
connected_checkbox = CheckboxGroup(labels=["Connected"], active=[1])
connected_checkbox.on_click(connected_checkbox_on_click)

doc.add_root(column(fig, connected_checkbox))

# Background receive loop and ~1 s heartbeat sender; daemon threads do not keep
# the process alive on exit.
recv_thread = threading.Thread(target=recv_thread_fn, daemon=True)
recv_thread.start()

send_thread = threading.Thread(target=send_thread_fn, daemon=True)
send_thread.start()
Ejemplo n.º 27
0
class StageViewer(Component):
    name = 'StageViewer'

    def __init__(self, config, tool, **kwargs):
        """
        Bokeh plot for showing the spread of TF across cells

        The figures, checkboxes and window markers are only declared here;
        ``create`` builds them and ``update_stages`` fills them with data.

        Parameters
        ----------
        config : traitlets.loader.Config
            Configuration specified by config file or cmdline arguments.
            Used to set traitlet values.
            Set to None if no configuration to pass.
        tool : ctapipe.core.Tool
            Tool executable that is calling this component.
            Passes the correct logger to the component.
            Set to None if no Tool to pass.
        kwargs
            Forwarded to the parent component.
        """
        super().__init__(config=config, parent=tool, **kwargs)
        self._active_pixel = 0

        # Bokeh objects, built by create()
        self.figures = self.cdsources = self.lines = None

        # Most recent inputs handed to create() / update_stages()
        self.x = self.stages = None
        self.neighbours2d = self.stage_list = self.active_stages = None
        self.w_pulse = self.w_integration = None

        # Stage checkbox widget and per-figure window-edge markers
        self.cb = None
        self.pulsewin1, self.pulsewin2 = [], []
        self.intwin1, self.intwin2 = [], []

        self.layout = None

    def create(self, neighbours2d, stage_list):
        """
        Build the 3x3 grid of stage figures and the stage-selection checkboxes.

        Parameters
        ----------
        neighbours2d : ndarray
            Indexed by pixel id. ``neighbours2d[pixel].ravel()`` gives the nine
            pixels drawn in the grid, with NaN where there is none.
            NOTE(review): presumably shaped (n_pixels, 3, 3) — confirm.
        stage_list : list of str
            Stage names. Each stage becomes one line per figure and one
            checkbox label.
        """
        self.neighbours2d = neighbours2d
        self.stage_list = stage_list

        palette = palettes.Set1[9]

        # Stages shown initially. The checkbox state, ``active_stages`` and the
        # initial line visibility all come from this one list, so the ticked
        # boxes match the drawn lines. Indices beyond the stage list are
        # dropped.
        initial_active = [j for j in (0, 5) if j < len(self.stage_list)]

        def make_window_span(color):
            # Dotted vertical marker for one edge of a pulse/integration window
            return Span(location=0,
                        dimension='height',
                        line_color=color,
                        line_dash='dotted')

        self.figures = []
        self.cdsources = []
        self.lines = defaultdict(list)
        # Start the marker lists afresh so update_stages() indexes the spans
        # attached to these figures even if create() runs more than once.
        self.pulsewin1 = []
        self.pulsewin2 = []
        self.intwin1 = []
        self.intwin2 = []
        legend_list = []
        for i in range(9):
            fig = figure(plot_width=400,
                         plot_height=200,
                         tools="",
                         toolbar_location=None,
                         outline_line_color='#595959')
            # Empty columns for now; update_stages() fills them
            cdsource = ColumnDataSource(
                data={'x': [], **{stage: [] for stage in self.stage_list}})
            for j, stage in enumerate(self.stage_list):
                color = palette[j % len(palette)]
                line = fig.line(source=cdsource, x='x', y=stage, color=color)
                line.visible = j in initial_active
                self.lines[stage].append(line)
                if i == 2:
                    # The legend lives on the top-right figure only
                    legend_list.append((stage, [line]))

            self.figures.append(fig)
            self.cdsources.append(cdsource)

            # Red markers: pulse window edges; green: integration window edges
            for markers, marker_color in ((self.pulsewin1, 'red'),
                                          (self.pulsewin2, 'red'),
                                          (self.intwin1, 'green'),
                                          (self.intwin2, 'green')):
                span = make_window_span(marker_color)
                markers.append(span)
                fig.add_layout(span)

            if i == 2:
                legend = Legend(items=legend_list,
                                location=(0, 0),
                                background_fill_alpha=0,
                                label_text_color='green')
                fig.add_layout(legend, 'right')

        self.cb = CheckboxGroup(labels=self.stage_list, active=initial_active)
        self.cb.on_click(self._on_checkbox_select)
        self.active_stages = [self.stage_list[j] for j in self.cb.active]

        rows = [self.figures[k:k + 3] for k in range(0, 9, 3)]
        figures = layout(rows)
        self.layout = layout([[self.cb, figures]])

    @property
    def active_pixel(self):
        """
        Pixel whose neighbourhood (``neighbours2d[active_pixel]``) is shown.

        Assigning a different value redraws the figures from the data last
        passed to ``update_stages``.
        """
        return self._active_pixel

    @active_pixel.setter
    def active_pixel(self, val):
        if self._active_pixel == val:
            return
        self._active_pixel = val
        # Nothing to redraw until update_stages() has supplied data. Calling
        # it with the initial None values would fail on ``self.stages[...]``.
        if self.stages is None:
            return
        self.update_stages(self.x, self.stages, self.w_pulse,
                           self.w_integration)

    def _get_neighbour_pixel(self, i):
        """
        Return the pixel id drawn in figure ``i`` around the active pixel.

        Returns None when that grid position holds NaN (no neighbour,
        presumably at the camera edge).
        """
        neighbour = self.neighbours2d[self.active_pixel].ravel()[i]
        return None if np.isnan(neighbour) else int(neighbour)

    def update_stages(self, x, stages, w_pulse, w_integration):
        """
        Store new stage data and redraw the nine neighbour figures.

        Parameters
        ----------
        x : array-like
            x coordinates shared by every stage line.
        stages : mapping
            Maps each name in ``stage_list`` to its values. A 2-D array is
            indexed by pixel; a 1-D array is drawn unchanged for every pixel.
            Any other rank is logged as an error and that column is omitted.
        w_pulse, w_integration : ndarray
            Indexed by pixel. For each pixel's window, the argmax gives the
            window start and the sum gives its length.
            NOTE(review): presumably 0/1 masks along the sample axis — confirm.
        """
        self.x, self.stages = x, stages
        self.w_pulse, self.w_integration = w_pulse, w_integration
        for i, cdsource in enumerate(self.cdsources):
            pixel = self._get_neighbour_pixel(i)
            if pixel is None:
                # No neighbour at this grid position: blank every column
                cdsource.data = {'x': [],
                                 **{stage: [] for stage in self.stage_list}}
                continue

            data = {'x': x}
            for stage in self.stage_list:
                values = self.stages[stage]
                if values.ndim == 2:
                    data[stage] = values[pixel]  # per-pixel values
                elif values.ndim == 1:
                    data[stage] = values  # same values for every pixel
                else:
                    self.log.error("Too many dimensions in stage values")
            cdsource.data = data

            # Window edges: start = first argmax of the window,
            # end = start + (window sum) - 1
            pulse_window = w_pulse[pixel]
            pulse_start = np.argmax(pulse_window)
            pulse_end = pulse_start + np.sum(pulse_window, axis=0) - 1
            integ_window = w_integration[pixel]
            integ_start = np.argmax(integ_window)
            integ_end = integ_start + np.sum(integ_window) - 1

            self.pulsewin1[i].location = pulse_start
            self.pulsewin2[i].location = pulse_end
            self.intwin1[i].location = integ_start
            self.intwin2[i].location = integ_end
        # NOTE(review): the reason for this pause is not evident from this
        # method — confirm before removing.
        sleep(0.1)
        self._update_yrange()

    def toggle_stage(self, stage, value):
        """Set the visibility of ``stage``'s line in every figure to ``value``."""
        for line in self.lines[stage]:
            line.visible = value

    def _on_checkbox_select(self, active):
        """
        Checkbox callback: show lines for ticked stages and hide the others.

        ``active`` is the value Bokeh passes in. The selection is read from
        ``self.cb.active`` instead.
        """
        self.active_stages = [self.stage_list[idx] for idx in self.cb.active]
        for stage in self.stage_list:
            self.toggle_stage(stage, stage in self.active_stages)
        self._update_yrange()

    def _update_yrange(self):
        """
        Fit each figure's y-range to the data of the active stages.

        A figure is left unchanged when ValueError is raised. That happens
        when no stage is active or a column is empty, because min/max get an
        empty sequence.
        """
        for fig, cdsource in zip(self.figures, self.cdsources):
            try:
                columns = [cdsource.data[stage] for stage in self.active_stages]
                lowest = min(min(column) for column in columns)
                highest = max(max(column) for column in columns)
                fig.y_range.start = lowest
                fig.y_range.end = highest
            except ValueError:
                pass  # Edge of camera