示例#1
0
 def createButtons(self, figNames):
     """Build one initially-active Toggle per name in FIGNAMES.

     Each toggle is labelled with its name and gets, as click handler, the
     callable returned by ``self.toggleFunction(toggle)``.

     Returns:
         list: the toggles, in the same order as FIGNAMES.
     """
     def _make_toggle(fig_name):
         widget = Toggle(label=fig_name, active=True)
         widget.on_click(self.toggleFunction(widget))
         return widget

     return [_make_toggle(fig_name) for fig_name in figNames]
示例#2
0
 def modify_doc(doc):
     """Add a toggle button and a two-point plot to ``doc``.

     The plot carries a CustomAction whose CustomJS callback runs
     ``RECORD("data", "s.data")``. Toggling the button swaps the source data
     between two fixed point sets.
     """
     source = ColumnDataSource(dict(x=[1, 2], y=[1, 1]))

     plot = Plot(plot_height=400, plot_width=400,
                 x_range=Range1d(0, 1), y_range=Range1d(0, 1),
                 min_border=0)
     plot.add_glyph(source, Circle(x='x', y='y', size=20))
     record_js = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
     plot.add_tools(CustomAction(callback=record_js))

     button = Toggle(css_classes=['foo'])

     def cb(value):
         # Truthy toggle state -> first point set, falsy -> second point set.
         source.data = (dict(x=[10, 20], y=[10, 10]) if value
                        else dict(x=[100, 200], y=[100, 100]))

     button.on_click(cb)
     doc.add_root(column(button, plot))
示例#3
0
 def modify_doc(doc):
     """Populate ``doc`` with a Toggle and a Plot of two circles.

     A CustomAction on the plot triggers CustomJS code built by
     ``RECORD("data", "s.data")``. The toggle's click handler replaces
     ``source.data`` depending on the new toggle state.
     """
     source = ColumnDataSource(dict(x=[1, 2], y=[1, 1]))
     plot = Plot(
         plot_height=400,
         plot_width=400,
         x_range=Range1d(0, 1),
         y_range=Range1d(0, 1),
         min_border=0,
     )
     plot.add_glyph(source, Circle(x='x', y='y', size=20))
     plot.add_tools(
         CustomAction(
             callback=CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
         )
     )

     button = Toggle(css_classes=['foo'])

     def cb(value):
         # Falsy state: show the far-away points and stop here.
         if not value:
             source.data = dict(x=[100, 200], y=[100, 100])
             return
         source.data = dict(x=[10, 20], y=[10, 10])

     button.on_click(cb)
     doc.add_root(column(button, plot))
示例#4
0
class button:
    """Thin wrapper around a Bokeh ``Toggle`` registered in a widget list.

    The Toggle is created right away (in ``__init__``) and appended to
    ``widget_lst``.
    """

    def __init__(self, widget_lst, label):
        self.widget_lst = widget_lst
        self.button = None
        self.initialize(label)

    def initialize(self, label):
        """Create the Toggle labelled ``label`` and append it to ``self.widget_lst``."""
        # The author notes a plain Button did not trigger callbacks here,
        # which is why a Toggle is used instead.
        toggle = Toggle(label=label)
        self.button = toggle
        self.widget_lst.append(toggle)

    def add_callback(self, callback):
        """Register ``callback`` as the Toggle's click handler."""
        self.button.on_click(callback)
示例#5
0
class button:
    """Bokeh ``Toggle`` wrapper whose widget is created lazily.

    ``__init__`` only stores the label and appends the wrapper itself to
    ``widget_lst``. The real Toggle is built later by ``initialize``. A
    callback given via ``add_callback`` is attached whether it arrives before
    or after ``initialize``.
    """

    def __init__(self, widget_lst, label):
        self.label = label
        self.button = None    # Toggle instance, created by initialize()
        self.callback = None  # click handler, stored by add_callback()
        widget_lst.append(self)

    def initialize(self, widget_lst):
        """Build the Toggle, append it to ``widget_lst``, and attach any stored callback."""
        # The author notes a plain Button did not trigger callbacks here,
        # which is why a Toggle is used instead.
        self.button = Toggle(label=self.label)
        widget_lst.append(self.button)
        pending = self.callback
        if pending is None:
            return
        self.button.on_click(pending)

    def add_callback(self, callback):
        """Remember ``callback`` and attach it immediately if the Toggle already exists."""
        self.callback = callback
        if self.button is None:
            return
        self.button.on_click(self.callback)
示例#6
0
        error_points.data = dict(x=new_line_x, y=new_line_y)

        error_param = error_land_data.data['x']
        error_val = [utils.compute_error(new_x, new_y, i) for i in error_param]
        error_land_data.data = dict(x=error_param, y=error_val)

    else:
        cluttered_button.button_type = 'primary'
        scatter_data.data = {'x': x, 'y': y}

        error_param = error_land_data.data['x']
        error_val = [utils.compute_error(x, y, i) for i in error_param]
        error_land_data.data = dict(x=error_param, y=error_val)


# Attach the clutter handler to the clutter button's click event.
cluttered_button.on_click(clutter_button_callback)


def button_draw_error_callback(attr):
    """Show or hide the error curve to match ``button_draw_error``'s state.

    ``attr`` is whatever ``on_click`` passes in; it is ignored, and the
    button's current ``active`` flag is read instead. Active means a
    'success' button with an opaque error glyph. Inactive means a 'primary'
    button with a fully transparent glyph.
    """
    shown = button_draw_error.active
    button_draw_error.button_type = 'success' if shown else 'primary'
    error_glyph.line_alpha = 1.0 if shown else 0.0


# Hook the error-curve toggle to its handler.
button_draw_error.on_click(button_draw_error_callback)

# Arrange the two buttons side by side, and the two plots side by side.
button_row = row(cluttered_button, button_draw_error)
plots = row(pair_plot, error_plot)
示例#7
0
# --- v0, y-component: slider over [-5, 5] in steps of 0.5, starting at -2 ---
v0_input_y = Slider(title=u"v\u2092-y", value=-2.0, start=-5.0, end=5.0, step=0.5)
v0_input_y.on_change('value', particle_speed_y)

# --- Reset button ---
reset_button = Button(label="Reset", button_type="success")
reset_button.on_click(reset_situation)

# --- Pause (a Toggle rather than a plain Button) ---
pause_button = Toggle(label="Pause", button_type="success")
pause_button.on_click(pause)

# --- Re-initialise button ---
reinit_button = Button(label="Re-initialise", button_type="success")
reinit_button.on_click(BackToInitial)

# --- Play button ---
play_button = Button(label="Play", button_type="success")
play_button.on_click(play)

# --- Reference-frame choice: radio group, "Room" (index 0) selected initially ---
Referential_button = RadioGroup(
    labels=["Reference frame: Room", "Reference frame: Disk"],
    active=0,
)
Referential_button.on_change('active', chooseRef)
## create drawing
示例#8
0

def callback_plot(event):
    """Redraw the power-dependence lines in ``p_power`` for the selected wavelengths.

    ``selected_nm.value`` is parsed as comma-separated integers and passed to
    ``p_dep``. Every renderer except the first is dropped, and one line is
    drawn per column of the result. ``event`` is unused.
    """
    # Clear any layouts currently attached on the right-hand side.
    if p_power.right != []:
        p_power.right = []

    nm_int = np.array(selected_nm.value.split(',')).astype(int)
    power_dep = p_dep(nm_int)
    # 'index' first; any column of the same name would overwrite it, as before.
    source_power = {
        'index': power_dep.index.astype(int).to_list(),
        **power_dep.to_dict(orient='list'),
    }

    # Keep only the first renderer, then draw one line per column.
    # NOTE(review): zip() stops after len(Viridis3) == 3 columns, so extra
    # wavelengths are silently not drawn. Also, y uses str(column) while the
    # source keys are the original column labels -- confirm they match.
    p_power.renderers = [p_power.renderers[0]]
    for col_name, line_color in zip(power_dep.columns.astype(str), Viridis3):
        p_power.line(x='index', y=col_name, source=source_power, color=line_color)

    # NOTE(review): this Legend is built but never attached to p_power
    # (attaching it via add_layout(legend, 'right') is currently disabled).
    legend = Legend(items=legend_plot(p_power, nm_int.astype(str)))


# event activation
# Connect widget/plot events to their handler functions.
nm.on_change('value', callback_nm)
# NOTE(review): module-level name `time` shadows the stdlib module of the same name.
time.on_change('value', callback_time)
p_center.on_event(events.Tap, callback)
# Three callbacks are registered on power_choice's 'value'; each runs on change.
power_choice.on_change('value', callback_power, callback_nm, callback_time)
show_all_power.on_click(callback_all_power)
plot_power.on_click(callback_plot)

# layout config
# Rows: [p_top, empty cell, widget], [nm], [p_center, time, p_right], [p_power].
layout = gridplot([[p_top, None, widget], [nm], [p_center, time, p_right],
                   [p_power]])
curdoc().add_root(layout)
示例#9
0
def modify_doc(doc):  # plotter
    """Build the SN2000CX spectrum viewer and add it to ``doc`` as a root.

    Layout, top to bottom:
      * a legend show/hide toggle;
      * the spectrum plot;
      * two columns of emission-line controls.

    Each control row holds a checkbox, a "z=" label and the matching input
    returned by ``emission_lines``. The first half of the lines goes in the
    left column and the remainder in the right one.

    Parameters
    ----------
    doc
        Bokeh document the layout is attached to via ``add_root``.
    """
    fname_list = [  # TODO: get this from API
        'data/sn2000cx-20000723-nickel.flm', 'data/sn2000cx-20000728-ui.flm',
        'data/sn2000cx-20000801-bluered.flm',
        'data/sn2000cx-20000802-bluered.flm',
        'data/sn2000cx-20000803-nickel.flm',
        'data/sn2000cx-20000805-nickel.flm',
        'data/sn2000cx-20000807-nickel.flm',
        'data/sn2000cx-20000810-nickel.flm',
        'data/sn2000cx-20000815-nickel.flm',
        'data/sn2000cx-20000818-nickel.flm',
        'data/sn2000cx-20000820-nickel.flm',
        'data/sn2000cx-20000822-nickel.flm',
        'data/sn2000cx-20000824-nickel.flm',
        'data/sn2000cx-20000826-nickel.flm', 'data/sn2000cx-20000827-ui.flm',
        'data/sn2000cx-20000906-ui.flm', 'data/sn2000cx-20000926-ui.flm',
        'data/sn2000cx-20001006-ui.flm', 'data/sn2000cx-20001024-ui.flm',
        'data/sn2000cx-20001101-ui.flm', 'data/sn2000cx-20001129-ui.flm',
        'data/sn2000cx-20001221-ui.flm'
    ]

    def legend_showhide():
        """Sync the toggle label and the legend visibility with the toggle state."""
        Labels = ['Hide Legend', 'Show Legend']
        status = legend_toggle.active
        legend_toggle.label = Labels[status]
        plot.legend.visible = not status

    # Plot
    plot = make_fig()
    raw_plot(fname_list,
             'SN2000CX',
             plot,
             return_figure=False,
             show_figure=True)

    # Emission lines: add every renderer of every line to the plot.
    line_list, z_in_list, checkboxes = emission_lines(emission_lines_data)
    for line in line_list:
        plot.renderers.extend(list(line))

    # other widgets
    legend_toggle = Toggle(label='Hide Legend',
                           button_type='success',
                           sizing_mode='scale_width')
    legend_toggle.on_click(lambda new: legend_showhide())

    # layout
    def line_row(k):
        """Return the control row for emission line ``k``: checkbox, "z=" label, input."""
        # A fresh label Div per row instead of one instance shared by all rows.
        z_label = Div(text="z=", sizing_mode='scale_width')
        return row([checkboxes[k], z_label, z_in_list[k]],
                   sizing_mode='scale_width')

    # Split at len // 2 so each line appears in exactly one column. This also
    # works for 0 or 1 lines (the left column is then empty).
    half = len(z_in_list) // 2
    line_layout_L = [line_row(k) for k in range(half)]
    line_layout_R = [line_row(k) for k in range(half, len(z_in_list))]

    layout = column([
        legend_toggle, plot,
        row(column(line_layout_L, sizing_mode='scale_width'),
            column(line_layout_R, sizing_mode='scale_width'),
            sizing_mode='scale_width')
    ],
                    sizing_mode='scale_width')

    # show
    doc.add_root(layout)
示例#10
0
plot.extra_y_ranges['secondary'] = Range1d(0, 100)

# --- file selection controls ---
file_selection_button = Button(
    label="Select Files", button_type="success", width=120,
)
file_selection_button.on_click(load_files_group)

files_selector_spacer = Spacer(width=10)

group_selection_button = Button(
    label="Select Directory", button_type="primary", width=140,
)
group_selection_button.on_click(load_directory_group)

update_files_button = Button(
    label="Update Files", button_type="default", width=50,
)
update_files_button.on_click(reload_all_files)

auto_update_toggle_button = Toggle(
    label="Auto Update", button_type="default", width=50, active=True,
)
auto_update_toggle_button.on_click(toggle_auto_update)

unload_file_button = Button(
    label="Unload", button_type="danger", width=50,
)
unload_file_button.on_click(unload_file)

# --- file picker (options start as a single empty entry) ---
files_selector = Select(title="Files:", options=[""])
files_selector.on_change('value', change_data_selector)

# --- data picker (multi-select, size=12) ---
data_selector = MultiSelect(title="Data:", options=[], size=12)
data_selector.on_change('value', select_data)

# --- x-axis choice: radio buttons over x_axis_options, first one active ---
x_axis_selector_title = Div(text="X Axis:", height=10)
x_axis_selector = RadioButtonGroup(labels=x_axis_options, active=0)
示例#11
0
class Samples(object):
    """Bokeh panel showing ``nmaps`` sampled surface maps and their light curves.

    For each sample it draws:
      * an image panel (``moll_plot``) -- presumably a Mollweide projection,
        given the ``moll`` names and the 2:1 elliptical outline;
      * a multi-line flux plot (``flux_plot``).

    Redrawing happens in the browser. A CustomJS callback is triggered from
    Python by changing the size of an invisible dummy glyph (see
    ``callback``).

    NOTE(review): ``self.Size``, ``self.Latitude`` and ``self.Contrast`` are
    used by the callbacks but never assigned in this class. They are
    presumably attached by the caller after construction; confirm.
    Module-level ``params`` supplies the default slider values.
    """

    def __init__(
        self,
        ydeg,
        npix,
        npts,
        nmaps,
        throttle_time,
        nosmooth,
        gp,
        sample_function,
    ):
        """Build design matrices, plots, widgets and the JS updater.

        Parameters
        ----------
        ydeg
            Degree of the expansion. Each sample has ``(ydeg + 1)**2``
            coefficients (``kmax`` in the JS code); presumably spherical
            harmonics, given the ``ylm`` naming.
        npix
            Image panels are ``npix`` x ``2 * npix``.
        npts
            Number of rotational-phase points per light curve (over [0, 2]).
        nmaps
            Number of sample panels (map + light curve) shown side by side.
        throttle_time
            Copied onto the Latitude/Size/Contrast ``throttle_time`` when the
            "smooth" toggle is active.
        nosmooth
            If truthy, the "smooth" toggle is disabled.
        gp
            Object whose ``random.seed(...)`` is called to reseed sampling.
        sample_function
            Called as ``sample_function(r, a, b, c, n)``. Element ``[0]`` of
            its result is used as the coefficient array ``ylm``.
        """
        # Settings
        self.ydeg = ydeg
        self.npix = npix
        self.npts = npts
        self.throttle_time = throttle_time
        self.nosmooth = nosmooth
        self.nmaps = nmaps
        self.gp = gp

        # Design matrices
        self.A_I = get_intensity_design_matrix(ydeg, npix)
        self.A_F = get_flux_design_matrix(ydeg, npts)

        def sample_ylm(r, mu_l, sigma_l, c, n):
            # Nudge mu_l off exactly 0 and 90 before gauss2beta, to avoid
            # issues at those boundaries.
            if mu_l == 0:
                mu_l = 1e-2
            elif mu_l == 90:
                mu_l = 90 - 1e-2
            a, b = gauss2beta(mu_l, sigma_l)
            return sample_function(r, a, b, c, n)

        self.sample_ylm = sample_ylm

        # Draw the initial samples from the default values in `params`
        # (element [0] of sample_function's result).
        self.ylm = self.sample_ylm(
            params["size"]["r"]["value"],
            params["latitude"]["mu"]["value"],
            params["latitude"]["sigma"]["value"],
            params["contrast"]["c"]["value"],
            params["contrast"]["n"]["value"],
        )[0]

        # Plot the GP ylm samples
        # The color range (0.5, 1.2) matches the RangeSlider's initial value below.
        self.color_mapper = LinearColorMapper(palette="Plasma256",
                                              nan_color="white",
                                              low=0.5,
                                              high=1.2)
        self.moll_plot = [None for i in range(self.nmaps)]
        # One image per sample: 1 + A_I @ ylm[i], reshaped to npix x 2*npix.
        self.moll_source = [
            ColumnDataSource(data=dict(image=[
                1.0 +
                (self.A_I @ self.ylm[i]).reshape(self.npix, 2 * self.npix)
            ])) for i in range(self.nmaps)
        ]
        eps = 0.1
        epsp = 0.02
        # Upper half of the outline ellipse x^2/4 + y^2 = 1 (drawn as +/- ye below).
        xe = np.linspace(-2, 2, 300)
        ye = 0.5 * np.sqrt(4 - xe**2)
        for i in range(self.nmaps):
            self.moll_plot[i] = figure(
                plot_width=280,
                plot_height=130,
                toolbar_location=None,
                x_range=(-2 - eps, 2 + eps),
                y_range=(-1 - eps / 2, 1 + eps / 2),
            )
            self.moll_plot[i].axis.visible = False
            self.moll_plot[i].grid.visible = False
            self.moll_plot[i].outline_line_color = None
            self.moll_plot[i].image(
                image="image",
                x=-2,
                y=-1,
                dw=4 + epsp,
                dh=2 + epsp / 2,
                color_mapper=self.color_mapper,
                source=self.moll_source[i],
            )
            self.moll_plot[i].toolbar.active_drag = None
            self.moll_plot[i].toolbar.active_scroll = None
            self.moll_plot[i].toolbar.active_tap = None

        # Plot lat/lon grid
        lat_lines = get_latitude_lines()
        lon_lines = get_longitude_lines()
        for i in range(self.nmaps):
            for x, y in lat_lines:
                self.moll_plot[i].line(x,
                                       y,
                                       line_width=1,
                                       color="black",
                                       alpha=0.25)
            for x, y in lon_lines:
                self.moll_plot[i].line(x,
                                       y,
                                       line_width=1,
                                       color="black",
                                       alpha=0.25)
            self.moll_plot[i].line(xe,
                                   ye,
                                   line_width=3,
                                   color="black",
                                   alpha=1)
            self.moll_plot[i].line(xe,
                                   -ye,
                                   line_width=3,
                                   color="black",
                                   alpha=1)

        # Colorbar slider
        self.slider = RangeSlider(
            start=0,
            end=1.5,
            step=0.01,
            value=(0.5, 1.2),
            orientation="horizontal",
            show_value=False,
            css_classes=["colorbar-slider"],
            direction="ltr",
            title="cmap",
        )
        self.slider.on_change("value", self.slider_callback)

        # Buttons
        self.seed_button = Button(
            label="re-seed",
            button_type="default",
            css_classes=["seed-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
        )
        self.seed_button.on_click(self.seed_callback)

        self.smooth_button = Toggle(
            label="smooth",
            button_type="default",
            css_classes=["smooth-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
            active=False,
        )
        self.smooth_button.disabled = bool(self.nosmooth)
        self.smooth_button.on_click(self.smooth_callback)

        # No click handler: `callback` just reads auto_button.active to decide
        # whether to reseed before each redraw.
        self.auto_button = Toggle(
            label="auto",
            button_type="default",
            css_classes=["auto-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
            active=True,
        )

        self.reset_button = Button(
            label="reset",
            button_type="default",
            css_classes=["reset-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
        )
        self.reset_button.on_click(self.reset_callback)

        # Light curve samples
        # Six curves per sample, labelled with inclinations 15..90 and colored
        # from Plasma6. NOTE(review): the hard-coded 6 presumably equals
        # A_F.shape[0] (lmax in the JS code) -- confirm.
        self.flux_plot = [None for i in range(self.nmaps)]
        self.flux_source = [
            ColumnDataSource(data=dict(
                xs=[np.linspace(0, 2, npts) for j in range(6)],
                ys=[fluxnorm(self.A_F[j] @ self.ylm[i]) for j in range(6)],
                color=[Plasma6[5 - j] for j in range(6)],
                inc=[15, 30, 45, 60, 75, 90],
            )) for i in range(self.nmaps)
        ]
        for i in range(self.nmaps):
            self.flux_plot[i] = figure(
                toolbar_location=None,
                x_range=(0, 2),
                y_range=None,
                min_border_left=50,
                plot_height=400,
            )
            if i == 0:
                self.flux_plot[i].yaxis.axis_label = "flux [ppt]"
                self.flux_plot[i].yaxis.axis_label_text_font_style = "normal"
            self.flux_plot[i].xaxis.axis_label = "rotational phase"
            self.flux_plot[i].xaxis.axis_label_text_font_style = "normal"
            self.flux_plot[i].outline_line_color = None
            self.flux_plot[i].multi_line(
                xs="xs",
                ys="ys",
                line_color="color",
                source=self.flux_source[i],
            )
            self.flux_plot[i].toolbar.active_drag = None
            self.flux_plot[i].toolbar.active_scroll = None
            self.flux_plot[i].toolbar.active_tap = None
            self.flux_plot[i].yaxis.major_label_orientation = np.pi / 4
            self.flux_plot[i].xaxis.axis_label_text_font_size = "8pt"
            self.flux_plot[i].xaxis.major_label_text_font_size = "8pt"
            self.flux_plot[i].yaxis.axis_label_text_font_size = "8pt"
            self.flux_plot[i].yaxis.major_label_text_font_size = "8pt"

        # Javascript callback to update light curves & images
        # The JS code indexes A_F, A_I, ylm and each image as flat arrays:
        #   A_F[kmax*(mmax*l + m) + k], A_I[kmax*(jmax*i + j) + k],
        #   ylm[kmax*n + k], image[jmax*i + j].
        # Light curves are renormalised in JS as 1e3 * ((1 + f) / (1 + mean) - 1).
        # Python's str.format fills in the loop bounds, hence the doubled {{ }}.
        self.A_F_source = ColumnDataSource(data=dict(A_F=self.A_F))
        self.A_I_source = ColumnDataSource(data=dict(A_I=self.A_I))
        self.ylm_source = ColumnDataSource(data=dict(ylm=self.ylm))
        callback = CustomJS(
            args=dict(
                A_F_source=self.A_F_source,
                A_I_source=self.A_I_source,
                ylm_source=self.ylm_source,
                flux_source=self.flux_source,
                moll_source=self.moll_source,
            ),
            code="""
            var A_F = A_F_source.data['A_F'];
            var A_I = A_I_source.data['A_I'];
            var ylm = ylm_source.data['ylm'];
            var i, j, k, l, m, n;
            for (n = 0; n < {nmax}; n++) {{

                // Update the light curves
                var flux = flux_source[n].data['ys'];
                for (l = 0; l < {lmax}; l++) {{
                    for (m = 0; m < {mmax}; m++) {{
                        flux[l][m] = 0.0;
                        for (k = 0; k < {kmax}; k++) {{
                            flux[l][m] += A_F[{kmax} * ({mmax} * l + m) + k] * ylm[{kmax} * n + k];
                        }}
                    }}
                    // Normalize
                    var mean = flux[l].reduce((previous, current) => current += previous) / {mmax};
                    for (m = 0; m < {mmax}; m++) {{
                        flux[l][m] = 1e3 * ((1 + flux[l][m]) / (1 + mean) - 1)
                    }}
                }}
                flux_source[n].change.emit();

                // Update the images
                var image = moll_source[n].data['image'][0];
                for (i = 0; i < {imax}; i++) {{
                    for (j = 0; j < {jmax}; j++) {{
                        image[{jmax} * i + j] = 1.0;
                        for (k = 0; k < {kmax}; k++) {{
                            image[{jmax} * i + j] += A_I[{kmax} * ({jmax} * i + j) + k] * ylm[{kmax} * n + k];
                        }}
                    }}
                }}
                moll_source[n].change.emit();

            }}
            """.format(
                imax=self.npix,
                jmax=2 * self.npix,
                kmax=(self.ydeg + 1)**2,
                nmax=self.nmaps,
                lmax=self.A_F.shape[0],
                mmax=self.npts,
            ),
        )
        # Invisible (alpha=0) dummy glyph. Changing its size from Python fires
        # the JS callback above (see `callback`).
        self.js_dummy = self.flux_plot[0].circle(x=0, y=0, size=1, alpha=0)
        self.js_dummy.glyph.js_on_change("size", callback)

        # Full layout
        self.plots = row(
            *[
                column(m, f, sizing_mode="scale_both")
                for m, f in zip(self.moll_plot, self.flux_plot)
            ],
            margin=(10, 30, 10, 30),
            sizing_mode="scale_both",
            css_classes=["samples"],
        )
        self.layout = grid([[self.plots]])

    def slider_callback(self, attr, old, new):
        """Apply the RangeSlider's (low, high) to the shared color mapper."""
        self.color_mapper.low, self.color_mapper.high = self.slider.value

    def seed_callback(self, event):
        """Reseed the GP's RNG with a random seed and redraw the samples.

        ``auto_button.active`` is saved before the redraw and restored after.
        NOTE(review): when auto is active, ``callback`` reseeds again itself,
        so the seed set here is overridden -- confirm intended.
        """
        self.gp.random.seed(np.random.randint(0, 99999))
        tmp = self.auto_button.active
        self.callback(None, None, None)
        self.auto_button.active = tmp

    def reset_callback(self, event):
        """Restore sliders, color range and toggles to defaults, then redraw.

        NOTE(review): only Size and Latitude callbacks are invoked (not
        Contrast). Also, ``callback`` runs with auto active, so it reseeds
        randomly after ``seed(0)`` -- confirm both are intended.
        """
        self.Size.sliders[0].value = params["size"]["r"]["value"]
        self.Latitude.sliders[0].value = params["latitude"]["mu"]["value"]
        self.Latitude.sliders[1].value = params["latitude"]["sigma"]["value"]
        self.Contrast.sliders[0].value = params["contrast"]["c"]["value"]
        self.Contrast.sliders[1].value = params["contrast"]["n"]["value"]
        self.Size.callback(None, None, None)
        self.Latitude.callback(None, None, None)
        self.gp.random.seed(0)
        self.slider.value = (0.5, 1.2)
        self.slider_callback(None, None, None)
        self.smooth_button.active = False
        self.smooth_callback(None)
        self.auto_button.active = True
        self.callback(None, None, None)

    def smooth_callback(self, event):
        """Set the parameter panels' throttle_time from the "smooth" toggle.

        Active means ``self.throttle_time``; inactive means 0.
        """
        if self.smooth_button.active:
            self.Latitude.throttle_time = self.throttle_time
            self.Size.throttle_time = self.throttle_time
            self.Contrast.throttle_time = self.throttle_time
        else:
            self.Latitude.throttle_time = 0
            self.Size.throttle_time = 0
            self.Contrast.throttle_time = 0

    def callback(self, attr, old, new):
        """Redraw samples from the current slider values and push them to the browser.

        On success every slider bar is colored white. On any exception the
        bars are colored firebrick and the error is printed.
        """
        try:

            if self.auto_button.active:
                self.gp.random.seed(np.random.randint(0, 99999))

            # Draw the samples
            self.ylm = self.sample_ylm(
                self.Size.sliders[0].value,
                self.Latitude.sliders[0].value,
                self.Latitude.sliders[1].value,
                self.Contrast.sliders[0].value,
                self.Contrast.sliders[1].value,
            )[0]

            # HACK: Trigger the JS callback by modifying
            # a property of the `js_dummy` glyph
            # (size alternates between 1 and 2 via 3 - size).
            self.ylm_source.data["ylm"] = self.ylm
            self.js_dummy.glyph.size = 3 - self.js_dummy.glyph.size

            # If everything worked, ensure the sliders are active
            for slider in (self.Size.sliders + self.Latitude.sliders +
                           self.Contrast.sliders):
                slider.bar_color = "white"

        except Exception as e:

            # Something went wrong inverting the covariance!
            for slider in (self.Size.sliders + self.Latitude.sliders +
                           self.Contrast.sliders):
                slider.bar_color = "firebrick"

            print("An error occurred when computing the covariance:")
            print(e)
示例#12
0
class Controls:
    """Slider/toggle panels that edit a simulation ``controller``'s parameters.

    Widgets fall into three sections (simulation, disease profile, disease
    response), each headed by a bold ``Div``. ``get_controls`` stacks all
    sections into one column.
    """

    def __init__(self, controller):
        """Build every widget from ``controller.params`` and wire the quarantine toggle.

        Slider bounds come from ``PARAMETERS[name][0]`` (start) and
        ``PARAMETERS[name][2]`` (end). Percentage parameters are displayed
        multiplied by 100. Each slider is passed through
        ``wrap(slider, controller, name)``.
        """
        # Per-instance section lists. As class attributes they were shared,
        # so each new Controls (e.g. one per served document) appended its
        # widgets to earlier instances' lists and reused their heading Divs.
        self.simulation_controls = [
            Div(text="""<strong>Simulation Controls</strong>""")
        ]
        self.disease_controls = [
            Div(text="""<strong>Disease Profile Controls</strong>""")
        ]
        self.response_controls = [
            Div(text="""<strong>Disease Response Controls</strong>""")
        ]

        def make_slider(name, title, step=1, scale=1):
            """Return a ``wrap``-ed Slider for ``controller.params[name]`` (shown x ``scale``)."""
            bounds = PARAMETERS[name]
            slider = Slider(start=bounds[0], end=bounds[2],
                            value=controller.params[name] * scale,
                            step=step, title=title)
            return wrap(slider, controller, name)

        # Simulation controls
        self.agents = make_slider('agents', "Number of agents")
        self.initial_immunity = make_slider(
            'initial_immunity', "Initial immunity (%)", scale=100)
        self.simulation_controls.append(self.agents)
        self.simulation_controls.append(self.initial_immunity)

        # Disease controls
        self.sickness_proximity = make_slider(
            'sickness_proximity', "Sickness proximity")
        self.sickness_duration = make_slider(
            'sickness_duration', "Sickness duration (ticks)")
        self.disease_controls.append(self.sickness_proximity)
        self.disease_controls.append(self.sickness_duration)

        # Response controls
        self.distancing_factor = make_slider(
            'distancing_factor', "Physical distancing factor (%)",
            step=0.5, scale=100)
        self.quarantine_delay = make_slider(
            'quarantine_delay', "Quarantine delay (ticks)")

        quarantining = controller.params['quarantining']
        self.quarantine_toggle = Toggle(
            label="Quarantine enabled" if quarantining else "Quarantine disabled",
            button_type="success" if quarantining else "danger",
            active=quarantining)

        self.response_controls.append(self.distancing_factor)
        self.response_controls.append(self.quarantine_delay)
        self.response_controls.append(self.quarantine_toggle)

        def toggle_callback(event):
            """Flip ``quarantining`` on the controller, restyle the toggle, reset the run."""
            controller.update_parameter(
                'quarantining', not controller.params['quarantining'])

            if controller.params['quarantining']:
                self.quarantine_toggle.label = "Quarantine enabled"
                self.quarantine_toggle.button_type = "success"

                # Re-send the current delay whenever quarantining is re-enabled.
                controller.update_parameter(
                    'quarantine_delay', controller.params['quarantine_delay'])
            else:
                self.quarantine_toggle.label = "Quarantine disabled"
                self.quarantine_toggle.button_type = "danger"

            controller.reset()

        self.quarantine_toggle.on_click(toggle_callback)

    def get_controls(self):
        """Return a column with the simulation, disease and response sections, in order."""
        return column(*self.simulation_controls, *self.disease_controls,
                      *self.response_controls)
示例#13
0
class PlayerWidget:
    def __init__(self, parent=None):
        """Initialise the widget's state from ``main_config`` and ``game_config``.

        Side effects:
          * the game log directory is cleaned;
          * the recording directory (and any missing parents) is created.

        Qt thread pools are created but nothing is submitted to them here.
        Game process, readers, recorder and predictor all start as ``None``.

        Parameters
        ----------
        parent
            Object providing ``add_next_tick_callback`` (used by the property
            setters); presumably a Bokeh document -- TODO confirm.
        """
        self.parent = parent
        self.fs = main_config['fs']
        self.n_channels = main_config['n_channels']
        self.t0 = 0
        self.last_ts = 0

        # Game log reader (separate thread)
        self.game_logs_path = game_config['game_logs_path']
        self.game_log_reader = None
        self._expected_action = (0, 'Rest')
        self.thread_log = QtCore.QThreadPool()

        # Game window (separate process)
        clean_log_directory(self.game_logs_path)
        self.game = None

        # Port event sender
        self.micro_path = main_config['micro_path']
        self.port_sender = None

        # Game player
        self.game_path = game_config['game_path']
        self.player_idx = game_config['player_idx']
        self.game_start_time = None

        # Chronogram
        self.chrono_source = ColumnDataSource(dict(ts=[],
                                                   y_true=[],
                                                   y_pred=[]))
        self.pred_decoding = main_config['pred_decoding']

        # LSL stream reader
        self.lsl_reader = None
        self.lsl_start_time = None
        self.thread_lsl = QtCore.QThreadPool()
        self.channel_source = ColumnDataSource(dict(ts=[], eeg=[]))
        self._lsl_data = (None, None)

        # LSL stream recorder
        # makedirs(exist_ok=True) avoids the isdir()/mkdir() race and also
        # creates missing parent directories.
        self.record_path = main_config['record_path']
        os.makedirs(self.record_path, exist_ok=True)
        self.record_name = game_config['record_name']
        self.lsl_recorder = None

        # Predictor
        # Input buffer covers 4 seconds: shape (n_channels, 4 * fs).
        self.models_path = main_config['models_path']
        self.input_signal = np.zeros((self.n_channels, 4 * self.fs))
        self.predictor = None
        self.thread_pred = QtCore.QThreadPool()
        self._pred_action = (0, 'Rest')

    @property
    def lsl_data(self):
        return self._lsl_data

    @lsl_data.setter
    def lsl_data(self, data):
        self._lsl_data = data

        # Memorize the most recent timestamp
        ts, eeg = data
        self.last_ts = ts[-1]

        # Record signal
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_data(copy.deepcopy(ts), copy.deepcopy(eeg))

        self.parent.add_next_tick_callback(self.update_signal)

    @property
    def pred_action(self):
        return self._pred_action

    @pred_action.setter
    def pred_action(self, val_tuple):
        self._pred_action = val_tuple
        if self.game_start_time is not None:
            self.parent.add_next_tick_callback(self.update_prediction)

    @property
    def expected_action(self):
        return self._expected_action

    @expected_action.setter
    def expected_action(self, action):
        logging.info(f'Receiving groundtruth from logs: {action}')
        self._expected_action = copy.deepcopy(action)

        # In autoplay, we directly update the model prediction (no delay)
        if self.modelfile == 'AUTOPLAY':
            self._pred_action = copy.deepcopy(action)

        self.parent.add_next_tick_callback(self.update_groundtruth)

    @property
    def available_logs(self):
        """Sorted list of log paths in ``game_logs_path`` matching the configured pattern."""
        pattern = game_config['game_logs_pattern']
        return sorted(self.game_logs_path.glob(pattern))

    @property
    def game_is_on(self):
        if self.game is not None:
            # Poll returns None when game process is running and 0 otherwise
            return self.game.poll() is None
        else:
            return False

    @property
    def should_record(self):
        return 'Record' in self.selected_settings

    @property
    def available_models(self):
        ml_models = [p.name for p in self.models_path.glob('*.pkl')]
        dl_models = [p.name for p in self.models_path.glob('*.h5')]
        return ['AUTOPLAY'] + ml_models + dl_models

    @property
    def model_name(self):
        return self.select_model.value

    @property
    def modelfile(self):
        if self.select_model.value == 'AUTOPLAY':
            return 'AUTOPLAY'
        else:
            return self.models_path / self.select_model.value

    @property
    def is_convnet(self):
        return self.select_model.value.split('.')[-1] == 'h5'

    @property
    def available_ports(self):
        if sys.platform == 'linux':
            ports = self.micro_path.glob('*')
            return [''] + [p.name for p in ports]
        elif sys.platform == 'win32':
            return [''] + [p.device for p in serial.tools.list_ports.comports()]

    @property
    def sending_events(self):
        return 'Send events' in self.selected_settings

    @property
    def channel_idx(self):
        return int(self.select_channel.value.split('-')[0])

    @property
    def selected_settings(self):
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def accuracy(self):
        """Accuracy of ``y_pred`` against ``y_true`` as accumulated in ``chrono_source``."""
        columns = self.chrono_source.data
        return accuracy_score(columns['y_true'], columns['y_pred'])

    def reset_lsl(self):
        if self.lsl_reader:
            self.lsl_reader.should_stream = False
            self.lsl_reader = None
            self.lsl_start_time = None
            self.thread_lsl.clear()

    def reset_predictor(self):
        """Flag the predictor to stop (``should_predict = False``), drop it and clear its pool.

        Does nothing when no predictor exists.
        """
        predictor = self.predictor
        if not predictor:
            return
        predictor.should_predict = False
        self.predictor = None
        self.thread_pred.clear()

    def reset_recorder(self):
        """Close the recorder's HDF5 file (if recording) and drop the recorder."""
        recorder = self.lsl_recorder
        if not recorder:
            return
        recorder.close_h5()
        self.lsl_recorder = None

    def reset_plots(self):
        """Empty the chronogram and stream sources and blank the info texts."""
        self.chrono_source.data = dict(ts=[], y_pred=[], y_true=[])
        self.channel_source.data = dict(ts=[], eeg=[])
        for info_div in (self.gd_info, self.pred_info, self.acc_info):
            info_div.text = ''

    def reset_game(self):
        """Terminate the game process (if any) and drop the handle.

        After ``kill()`` the process is waited on, so it does not linger as a
        zombie. Its stdin/stdout/stderr pipes are then closed to release the
        file descriptors. Safe to call when no game is running.
        """
        game = self.game
        self.game = None
        if game is None:
            return

        game.kill()
        try:
            game.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logging.warning('Game process did not exit after kill()')

        for stream in (game.stdin, game.stdout, game.stderr):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    # e.g. BrokenPipeError when flushing stdin of a dead process
                    pass

    def reset_log_reader(self):
        """Clear the log-reader thread pool and forget the current game log reader."""
        logging.info('Delete old log reader')
        pool = self.thread_log
        pool.clear()
        self.game_log_reader = None

    def on_model_change(self, attr, old, new):
        """Bokeh ``on_change('value')`` callback of the model selector.

        Refreshes the selector options from disk, shows the new model name and
        schedules ``start_predictor_thread`` on the next tick.
        """
        logging.info(f'Select new pre-trained model {new}')
        selector = self.select_model
        selector.options = self.available_models
        self.model_info.text = f'<b>Model:</b> {new}'
        self.parent.add_next_tick_callback(self.start_predictor_thread)

    def on_select_port(self, attr, old, new):
        """Bokeh ``on_change('value')`` callback of the port selector.

        Drops any previous command sender and opens a ``CommandSenderPort`` on
        the newly selected port. The empty entry (first option of
        ``available_ports``) means "no port", so only the old sender is dropped.
        """
        logging.info(f'Select new port: {new}')

        if self.port_sender is not None:
            logging.info('Delete old port sender')
            self.port_sender = None

        if not new:
            logging.info('No port selected')
            return

        logging.info(f'Instantiate port sender {new}')
        self.port_sender = CommandSenderPort(new)

    def on_channel_change(self, attr, old, new):
        """Bokeh ``on_change('value')`` callback of the channel selector.

        Clears the displayed stream and relabels the amplitude axis.
        """
        logging.info(f'Select new channel {new}')
        # Reset both columns together. Clearing only 'eeg' left 'ts' at its old
        # length, so the ColumnDataSource columns had unequal lengths.
        self.channel_source.data = dict(ts=[], eeg=[])
        self.plot_stream.yaxis.axis_label = f'Amplitude ({new})'

    def on_settings_change(self, attr, old, new):
        """Show the EEG stream plot only while button 0 ('Show signal') is active."""
        show_signal = 0 in new
        self.plot_stream.visible = show_signal

    def start_game_process(self):
        """Launch the game executable at ``self.game_path`` as a child process.

        Remembers how many race logs exist beforehand (``n_old_logs``) so that
        ``start_log_reader`` can detect the new one. Also clears the plots and
        terminates any previous game process. Raises ``OSError`` (e.g.
        ``FileNotFoundError``) if the executable cannot be started.
        """
        logging.info('Launching Cybathlon game')
        self.n_old_logs = len(self.available_logs)
        self.reset_plots()

        # Close any previous game process
        if self.game is not None:
            self.reset_game()

        # NOTE(review): stdout/stderr are piped but not read in this block. If
        # nothing else drains them, a chatty game could block once the pipe
        # buffer fills. Confirm whether the pipes are needed.
        # Popen either returns a process object or raises, so no None check.
        self.game = subprocess.Popen(str(self.game_path),
                                     stdin=subprocess.PIPE,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     text=True)

    def start_log_reader(self):
        """Start a ``GameLogReader`` on the race log created by the running game.

        Polls every 0.5 s until more logs exist than were counted in
        ``start_game_process`` (``self.n_old_logs``), then reads the newest one
        in the log thread pool. Returns without a reader if the game process
        stops before producing a log, instead of polling forever.
        """
        # Replace any existing log reader
        if self.game_log_reader is not None:
            self.reset_log_reader()

        # Wait for a new logfile to be created.
        # NOTE(review): this sleeps on the calling thread. It is scheduled via
        # add_next_tick_callback, so it presumably blocks the Bokeh server loop.
        while len(self.available_logs) <= self.n_old_logs:
            if not self.game_is_on:
                logging.info('Game process exited before creating a race log')
                return
            logging.info('Waiting for new race logs...')
            time.sleep(0.5)
        log_filename = str(self.available_logs[-1])

        # Log reader is started in a separate thread
        logging.info(f'Instantiate log reader {log_filename}')
        self.game_log_reader = GameLogReader(self, log_filename,
                                             self.player_idx)
        self.thread_log.start(self.game_log_reader)

    def on_launch_game_start(self):
        """Click handler of the launch button.

        Shows a pending state immediately, then performs the launch on the
        next tick, so the label update reaches the browser first.
        """
        self.button_launch_game.label = 'Launching...'
        self.button_launch_game.button_type = 'warning'
        self.parent.add_next_tick_callback(self.on_launch_game)

    def on_launch_game(self):
        """Start the game process and its log reader, then update the launch button.

        If the executable cannot be started (``OSError``), the error is logged
        and the button shows a failure state, rather than staying in its
        intermediate warning state.
        """
        try:
            self.start_game_process()
        except OSError:
            logging.exception('Failed to launch Cybathlon game')
            self.button_launch_game.label = 'Launch failed'
            self.button_launch_game.button_type = 'danger'
            return

        self.start_log_reader()
        self.button_launch_game.label = 'Launched'
        self.button_launch_game.button_type = 'success'

    def update_groundtruth(self):
        """React to the game's current expected action and optionally forward it.

        The action name from ``self.expected_action`` selects the reaction:

        - 'Game start': reset the plots and timestamp the race. For AUTOPLAY,
          also schedule a predictor.
        - 'Game end' / 'Pause': stop the predictor.
        - 'Resume': schedule a predictor again.
        - 'Reset game': clear plots, predictor and log reader, then schedule a
          new log reader.

        When 'Send events' is ticked, the action index is sent to the port sender.
        """
        action_idx, action_name = self.expected_action

        if action_name == 'Game start':
            # New run: start a fresh chronogram and clock
            self.reset_plots()
            self.game_start_time = time.time()
            if self.modelfile == 'AUTOPLAY':
                self.parent.add_next_tick_callback(self.start_predictor_thread)
        elif action_name in ('Game end', 'Pause'):
            self.reset_predictor()
        elif action_name == 'Resume':
            self.parent.add_next_tick_callback(self.start_predictor_thread)
        elif action_name == 'Reset game':
            self.reset_plots()
            self.reset_predictor()
            self.reset_log_reader()
            self.parent.add_next_tick_callback(self.start_log_reader)

        # Forward the ground truth to the microcontroller, if requested
        if not self.sending_events:
            return
        if self.port_sender is None:
            logging.info('Please select a port !')
            return
        self.port_sender.sendCommand(action_idx)
        logging.info(f'Send event: {action_idx}')

    def update_prediction(self):
        """Log the latest prediction next to the ground truth and refresh the display.

        If the game process is no longer running, the launch button and model
        selector are reset and the predictor is stopped instead.
        """
        if not self.game_is_on:
            logging.info('Game window was closed')
            launch = self.button_launch_game
            launch.label = 'Launch Game'
            launch.button_type = 'primary'
            self.select_model.value = 'AUTOPLAY'
            self.reset_predictor()
            return

        truth_idx = self.expected_action[0]
        pred_idx = self.pred_action[0]

        # Event marker: digits of 2*(truth+1) followed by digits of 2*(pred+1)
        if self.lsl_recorder is not None:
            marker = int(f'{(truth_idx+1)*2}{(pred_idx+1)*2}')
            self.lsl_recorder.save_event(self.last_ts, marker)

        # One chronogram point, timed from the race start
        elapsed = time.time() - self.game_start_time
        self.chrono_source.stream(dict(ts=[elapsed],
                                       y_true=[truth_idx],
                                       y_pred=[pred_idx]))

        self.gd_info.text = f'<b>Groundtruth:</b> {self.expected_action}'
        self.pred_info.text = f'<b>Prediction:</b> {self.pred_action}'
        self.acc_info.text = f'<b>Accuracy:</b> {self.accuracy:.2f}'

    def on_lsl_connect_toggle(self, active):
        """Click handler of the LSL toggle.

        When activated, shows a pending state and connects on the next tick.
        When deactivated, stops the reader and shows 'LSL Disconnected'.
        """
        if active:
            # Connect to LSL stream
            self.button_lsl.label = 'Searching...'
            self.button_lsl.button_type = 'warning'
            self.parent.add_next_tick_callback(self.start_lsl_thread)
        else:
            self.reset_lsl()
            self.button_lsl.label = 'LSL Disconnected'
            self.button_lsl.button_type = 'danger'

    def start_lsl_thread(self):
        """Connect to an LSL stream and run the reader in the LSL thread pool.

        On success, the channel selector is filled with '<n> - <name>' labels
        (n counts from 1) and the button turns green. On any exception, the
        traceback is logged, the button turns red and the reader is discarded.
        """
        try:
            self.lsl_reader = LSLClient(self)
            if self.lsl_reader is None:
                return
            self.select_channel.options = [
                f'{number} - {name}'
                for number, name in enumerate(self.lsl_reader.ch_names, start=1)
            ]
            self.thread_lsl.start(self.lsl_reader)
            self.button_lsl.label = 'Reading LSL stream'
            self.button_lsl.button_type = 'success'
        except Exception:
            logging.info(f'No LSL stream - {traceback.format_exc()}')
            self.button_lsl.label = 'Can\'t find stream'
            self.button_lsl.button_type = 'danger'
            self.reset_lsl()

    def start_predictor_thread(self):
        """Replace the current predictor with one for the selected model.

        If construction or start-up fails, the error is logged, the selector
        falls back to 'AUTOPLAY' and the predictor is reset again.
        """
        self.reset_predictor()

        try:
            predictor = ActionPredictor(self, self.modelfile, self.is_convnet)
            self.predictor = predictor
            self.thread_pred.start(predictor)
        except Exception as err:
            logging.error(f'Failed loading model {self.modelfile} - {err}')
            self.select_model.value = 'AUTOPLAY'
            self.reset_predictor()

    def on_lsl_record_toggle(self, active):
        """Click handler of the record toggle: open or close the HDF5 recording.

        Opening uses the channel names of the current LSL reader. Any failure
        (e.g. no reader connected) is logged, and the button shows
        'Recording failed'.
        """
        if active:
            try:
                self.lsl_recorder = LSLRecorder(self.record_path,
                                                self.record_name,
                                                self.lsl_reader.ch_names)
                self.lsl_recorder.open_h5()
                self.button_record.label = 'Stop recording'
                self.button_record.button_type = 'success'
            except Exception as e:
                # Report the cause instead of failing silently
                logging.error(f'Failed creating LSLRecorder - {e}')
                self.reset_recorder()
                self.button_record.label = 'Recording failed'
                self.button_record.button_type = 'danger'
        else:
            self.reset_recorder()
            self.button_record.label = 'Start recording'
            self.button_record.button_type = 'primary'

    def update_signal(self):
        """Consume the latest LSL chunk ``(ts, eeg)`` from ``self.lsl_data``.

        ``eeg`` is indexed as (channels, samples), as ``eeg[row, :]`` and the
        column roll below imply. Processing steps:

        - Stream the selected channel to the plot.
        - Shift the chunk into the rolling ``input_signal``.
        - Remember the newest timestamp in ``last_ts`` (used to stamp events).
        - Save the chunk when recording.

        Empty or malformed chunks are skipped.
        """
        ts, eeg = self.lsl_data

        # An empty chunk would crash below (ts[0], zero-width slice assignment)
        if ts.shape[0] == 0:
            return

        if ts.shape[0] != eeg.shape[-1]:
            logging.info('Skipping data points (bad format)')
            return

        # Newest timestamp, used by update_prediction to stamp events
        self.last_ts = ts[-1]

        # Convert timestamps in seconds
        if self.lsl_start_time is None:
            self.lsl_start_time = time.time()
            self.t0 = ts[0]

        # Update source display. Channel labels are 1-based ('1 - Fp1' is
        # row 0 of eeg), so convert to a 0-based row index.
        ch_row = self.channel_idx - 1
        self.channel_source.stream(dict(ts=ts-self.t0, eeg=eeg[ch_row, :]),
                                   rollover=int(2 * self.fs))

        # Update signal
        chunk_size = eeg.shape[-1]
        self.input_signal = np.roll(self.input_signal, -chunk_size, axis=-1)
        self.input_signal[:, -chunk_size:] = eeg

        # Record signal (copies, since the arrays are handed to the recorder)
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_data(ts.copy(), eeg.copy())

    def create_widget(self):
        """Build the panel's Bokeh widgets and plots and return its layout.

        Every widget is stored on ``self`` (the callbacks read them later) and
        wired to its callback here. Returns a ``row`` of three columns:
        controls, plots (stream + chronogram) and information texts.
        """
        # Button - Launch Cybathlon game in new window
        self.button_launch_game = Button(label='Launch Game',
                                         button_type='primary')
        self.button_launch_game.on_click(self.on_launch_game_start)

        # Toggle - Connect to LSL stream
        self.button_lsl = Toggle(label='Connect to LSL')
        self.button_lsl.on_click(self.on_lsl_connect_toggle)

        # Toggle - Start/stop LSL stream recording
        self.button_record = Toggle(label='Start Recording',
                                    button_type='primary')
        self.button_record.on_click(self.on_lsl_record_toggle)

        # Select - Choose pre-trained model (options scanned from models_path)
        self.select_model = Select(title="Select pre-trained model",
                                   value='AUTOPLAY',
                                   options=self.available_models)
        self.select_model.on_change('value', self.on_model_change)

        # Select - Choose port to send events to ('' = no port)
        self.select_port = Select(title='Select port')
        self.select_port.options = self.available_ports
        self.select_port.on_change('value', self.on_select_port)

        # Checkbox - Choose player settings.
        # Index 0 ('Show signal') drives plot_stream visibility in
        # on_settings_change; 'Send events' is read by sending_events.
        # NOTE(review): there is no 'Record' label here, so should_record can
        # never be True. Confirm whether that option was dropped on purpose.
        self.div_settings = Div(text='<b>Settings</b>', align='center')
        self.checkbox_settings = CheckboxButtonGroup(labels=['Show signal',
                                                             'Send events'])
        self.checkbox_settings.on_change('active', self.on_settings_change)

        # Select - Channel to visualize. The value is a placeholder; the
        # options are filled by start_lsl_thread from the stream's channels.
        self.select_channel = Select(title='Select channel', value='1 - Fp1')
        self.select_channel.on_change('value', self.on_channel_change)

        # Plot - LSL EEG Stream (hidden until 'Show signal' is ticked)
        self.plot_stream = figure(title='Temporal EEG signal',
                                  x_axis_label='Time [s]',
                                  y_axis_label='Amplitude',
                                  plot_height=500,
                                  plot_width=800,
                                  visible=False)
        self.plot_stream.line(x='ts', y='eeg', source=self.channel_source)

        # Plot - Chronogram prediction vs results
        self.plot_chronogram = figure(title='Chronogram',
                                      x_axis_label='Time [s]',
                                      y_axis_label='Action',
                                      plot_height=300,
                                      plot_width=800)
        self.plot_chronogram.line(x='ts', y='y_true', color='blue',
                                  source=self.chrono_source,
                                  legend_label='Groundtruth')
        self.plot_chronogram.cross(x='ts', y='y_pred', color='red',
                                   source=self.chrono_source,
                                   legend_label='Prediction')
        self.plot_chronogram.legend.background_fill_alpha = 0.6
        # Y ticks at the action indices, labelled with the action names
        self.plot_chronogram.yaxis.ticker = list(self.pred_decoding.keys())
        self.plot_chronogram.yaxis.major_label_overrides = self.pred_decoding

        # Div - Display useful information
        self.model_info = Div(text=f'<b>Model:</b> AUTOPLAY')
        self.pred_info = Div()
        self.gd_info = Div()
        self.acc_info = Div()

        # Create layout
        column1 = column(self.button_launch_game, self.button_lsl,
                         self.button_record, self.select_model,
                         self.select_port, self.select_channel,
                         self.div_settings, self.checkbox_settings)
        column2 = column(self.plot_stream, self.plot_chronogram)
        column3 = column(self.model_info, self.gd_info,
                         self.pred_info, self.acc_info)
        return row(column1, column2, column3)
def turn_decision_off(new):
    """
    Show or hide the decision text on both plots.

    ``new`` is the toggle state. When truthy, the texts are shown and the
    button reads "Sonuç gösterme" ("don't show result"). When falsy, they are
    hidden and the button reads "Sonuç göster" ("show result").
    """
    show = bool(new)
    for target in (p, best_root_plot):
        target.select(name="decision_text").visible = show
    decision_button.label = "Sonuç gösterme" if show else "Sonuç göster"


decision_button.on_click(turn_decision_off)


def turn_arrow_labels_off(new):
    """
    Show or hide the arrow (decision value) labels on both plots.

    A truthy ``new`` shows them, and the button reads "Karar değerlerini
    gösterme" ("don't show decision values"). A falsy ``new`` hides them, and
    the button reads "Karar değerlerini göster" ("show decision values").
    """
    visible = bool(new)
    for target in (p, best_root_plot):
        target.select(name="arrowLabels").visible = visible
    if visible:
        arrow_button.label = "Karar değerlerini gösterme"
    else:
        arrow_button.label = "Karar değerlerini göster"
示例#15
0
class WarmUpWidget:
    """Bokeh panel for warm-up sessions: live EEG over LSL, model predictions, feedback.

    An ``LSLClient`` reads the stream and an ``ActionPredictor`` classifies
    the rolling ``input_signal``. Both run in ``QtCore.QThreadPool``
    instances. The ``lsl_data`` and ``pred_action`` setters schedule the
    matching UI update on ``parent`` (presumably the Bokeh document) via
    ``add_next_tick_callback``. The stream can be recorded with ``LSLRecorder``.
    """

    def __init__(self, parent=None):
        self.parent = parent
        self.fs = main_config['fs']
        self.n_channels = main_config['n_channels']
        self.t0 = 0
        self.last_ts = 0
        self.game_is_on = False

        # Chronogram
        self.chrono_source = ColumnDataSource(dict(ts=[], y_pred=[]))
        self.pred_decoding = main_config['pred_decoding']

        # LSL stream reader
        self.lsl_reader = None
        self.lsl_start_time = None
        self._lsl_data = (None, None)
        self.thread_lsl = QtCore.QThreadPool()
        self.channel_source = ColumnDataSource(dict(ts=[], eeg=[]))
        self.buffer_size_s = 10  # rollover of the stream plot is buffer_size_s * fs

        # LSL stream recorder (makedirs also creates missing parent folders)
        os.makedirs(main_config['record_path'], exist_ok=True)
        self.record_path = main_config['record_path']
        self.record_name = warmup_config['record_name']
        self.lsl_recorder = None

        # Predictor: rolling input window of 4 * fs samples per channel
        self.models_path = main_config['models_path']
        self.input_signal = np.zeros((self.n_channels, 4 * self.fs))
        self.predictor = None
        self.thread_pred = QtCore.QThreadPool()
        self._pred_action = (0, 'Rest')

        # Feedback images
        self.static_folder = warmup_config['static_folder']
        self.action2image = warmup_config['action2image']

    @property
    def pred_action(self):
        """Latest ``(action_idx, action_name)`` prediction."""
        return self._pred_action

    @pred_action.setter
    def pred_action(self, val_tuple):
        # Store the prediction and refresh the UI on the next tick
        self._pred_action = val_tuple
        self.parent.add_next_tick_callback(self.update_prediction)

    @property
    def lsl_data(self):
        """Latest ``(ts, eeg)`` chunk received from the LSL reader."""
        return self._lsl_data

    @lsl_data.setter
    def lsl_data(self, data):
        # Store the chunk and process it on the next tick
        self._lsl_data = data
        self.parent.add_next_tick_callback(self.update_signal)

    @property
    def available_models(self):
        """'' (no model), then the *.pkl and *.h5 files found in ``models_path``."""
        ml_models = [p.name for p in self.models_path.glob('*.pkl')]
        dl_models = [p.name for p in self.models_path.glob('*.h5')]
        return [''] + ml_models + dl_models

    @property
    def selected_settings(self):
        """Labels of the active buttons in the settings group."""
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def modelfile(self):
        """Path of the selected model file."""
        return self.models_path / self.select_model.value

    @property
    def model_name(self):
        """Name of the selected model file."""
        return self.select_model.value

    @property
    def is_convnet(self):
        """True when the selected model file has the 'h5' extension."""
        return self.select_model.value.split('.')[-1] == 'h5'

    @property
    def channel_idx(self):
        """1-based channel number parsed from a label such as '3 - Cz'."""
        return int(self.select_channel.value.split('-')[0])

    def reset_lsl(self):
        """Flag the LSL reader to stop, drop it and clear its thread pool."""
        if self.lsl_reader:
            self.lsl_reader.should_stream = False
            self.lsl_reader = None
            self.lsl_start_time = None
            self.thread_lsl.clear()

    def reset_predictor(self):
        """Blank the prediction display and stop/drop the predictor, if any."""
        self.model_info.text = '<b>Model:</b> None'
        self.pred_info.text = '<b>Prediction:</b> None'
        self.image.text = ''
        if self.predictor:
            self.predictor.should_predict = False
            self.predictor = None
            self.thread_pred.clear()

    def reset_recorder(self):
        """Close the recording file (if any) and drop the recorder."""
        if self.lsl_recorder:
            self.lsl_recorder.close_h5()
            self.lsl_recorder = None

    def on_settings_change(self, attr, old, new):
        """Show the stream plot only while 'Show signal' (index 0) is active."""
        self.plot_stream.visible = 0 in new

    def on_model_change(self, attr, old, new):
        """Bokeh callback: replace the running predictor with one for model ``new``.

        Selecting the empty entry only stops the current predictor.
        """
        logging.info(f'Select new pre-trained model {new}')
        self.select_model.options = self.available_models

        # Delete existing predictor thread
        if self.predictor is not None:
            self.reset_predictor()

        # '' means "no model". Without this check the models directory itself
        # would be passed to ActionPredictor when no predictor was running.
        if new == '':
            return

        try:
            self.predictor = ActionPredictor(self, self.modelfile,
                                             self.is_convnet)
            self.thread_pred.start(self.predictor)
            self.model_info.text = f'<b>Model:</b> {new}'
        except Exception as e:
            logging.error(f'Failed loading model {self.modelfile} - {e}')
            self.reset_predictor()

    def on_channel_change(self, attr, old, new):
        """Bokeh callback: clear the stream plot and relabel its amplitude axis."""
        logging.info(f'Select new channel {new}')
        self.channel_source.data = dict(ts=[], eeg=[])
        self.plot_stream.yaxis.axis_label = f'Amplitude ({new})'

    def reset_plots(self):
        """Empty the chronogram and stream plot sources."""
        self.chrono_source.data = dict(ts=[], y_pred=[])
        self.channel_source.data = dict(ts=[], eeg=[])

    def update_prediction(self):
        """Plot the latest prediction, show its feedback image and record it as an event."""
        # Update chronogram source (only once the LSL stream has started)
        action_idx = self.pred_action[0]
        if self.lsl_start_time is not None:
            ts = time.time() - self.lsl_start_time
            self.chrono_source.stream(dict(ts=[ts], y_pred=[action_idx]))

        # Update information display (might cause delay)
        self.pred_info.text = f'<b>Prediction:</b> {self.pred_action}'
        src = self.static_folder / \
            self.action2image[self.pred_decoding[action_idx]]
        self.image.text = f"<img src={src} width='200' height='200' text-align='center'>"

        # Save prediction as event
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_event(copy.deepcopy(self.last_ts),
                                         copy.deepcopy(action_idx))

    def on_lsl_connect_toggle(self, active):
        """Click handler of the LSL toggle: connect on the next tick, or disconnect."""
        if active:
            # Connect to LSL stream
            self.button_lsl.label = 'Searching...'
            self.button_lsl.button_type = 'warning'
            self.reset_plots()
            self.parent.add_next_tick_callback(self.start_lsl_thread)
        else:
            self.reset_lsl()
            self.reset_predictor()
            self.button_lsl.label = 'LSL Disconnected'
            self.button_lsl.button_type = 'danger'

    def start_lsl_thread(self):
        """Connect to an LSL stream, adopt its sampling rate and start reading it."""
        try:
            self.lsl_reader = LSLClient(self)
            self.fs = self.lsl_reader.fs

            if self.lsl_reader is not None:
                # Labels are 1-based: '1 - <first channel>'
                self.select_channel.options = [
                    f'{i+1} - {ch}'
                    for i, ch in enumerate(self.lsl_reader.ch_names)
                ]
                self.thread_lsl.start(self.lsl_reader)
                self.button_lsl.label = 'Reading LSL stream'
                self.button_lsl.button_type = 'success'
        except Exception:
            logging.info(f'No LSL stream - {traceback.format_exc()}')
            self.button_lsl.label = 'Can\'t find stream'
            self.button_lsl.button_type = 'danger'
            self.reset_lsl()

    def on_lsl_record_toggle(self, active):
        """Click handler of the record toggle: open or close the HDF5 recording."""
        if active:
            try:
                self.lsl_recorder = LSLRecorder(self.record_path,
                                                self.record_name,
                                                self.lsl_reader.ch_names)
                self.lsl_recorder.open_h5()
                self.button_record.label = 'Stop recording'
                self.button_record.button_type = 'success'
            except Exception as e:
                logging.info(f'Failed creating LSLRecorder - {e}')
                self.reset_recorder()
                self.button_record.label = 'Recording failed'
                self.button_record.button_type = 'danger'
        else:
            self.reset_recorder()
            self.button_record.label = 'Start recording'
            self.button_record.button_type = 'primary'

    def update_signal(self):
        """Consume the latest LSL chunk: plot it, shift it into ``input_signal``, record it.

        ``eeg`` is indexed as (channels, samples), as ``eeg[row, :]`` and the
        column roll imply. Empty or malformed chunks are skipped.
        """
        ts, eeg = self.lsl_data

        # An empty chunk would crash below (ts[0], zero-width slice assignment)
        if ts.shape[0] == 0:
            return

        if ts.shape[0] != eeg.shape[-1]:
            logging.info('Skipping data points (bad format)')
            return

        # Newest timestamp, used to stamp prediction events
        self.last_ts = ts[-1]

        # Local LSL start time
        if self.lsl_start_time is None:
            self.lsl_start_time = time.time()
            self.t0 = ts[0]

        # Update source display. Channel labels are 1-based ('1 - Fp1' is
        # row 0 of eeg), so convert to a 0-based row index.
        ch_row = self.channel_idx - 1
        self.channel_source.stream(dict(ts=(ts - self.t0) / self.fs,
                                        eeg=eeg[ch_row, :]),
                                   rollover=int(self.buffer_size_s * self.fs))

        # Update signal
        chunk_size = eeg.shape[-1]
        self.input_signal = np.roll(self.input_signal, -chunk_size, axis=-1)
        self.input_signal[:, -chunk_size:] = eeg

        # Record signal
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_data(copy.deepcopy(ts), copy.deepcopy(eeg))

    def create_widget(self):
        """Build the panel's widgets/plots, wire their callbacks and return the layout."""
        # Toggle - Connect to LSL stream
        self.button_lsl = Toggle(label='Connect to LSL')
        self.button_lsl.on_click(self.on_lsl_connect_toggle)

        # Toggle - Start/stop LSL stream recording
        self.button_record = Toggle(label='Start Recording',
                                    button_type='primary')
        self.button_record.on_click(self.on_lsl_record_toggle)

        # Select - Choose pre-trained model
        self.select_model = Select(title="Select pre-trained model")
        self.select_model.options = self.available_models
        self.select_model.on_change('value', self.on_model_change)

        # Checkbox - Choose settings
        self.div_settings = Div(text='<b>Settings</b>', align='center')
        self.checkbox_settings = CheckboxButtonGroup(labels=['Show signal'])
        self.checkbox_settings.on_change('active', self.on_settings_change)

        # Select - Channel to visualize (options filled by start_lsl_thread)
        self.select_channel = Select(title='Select channel', value='1 - Fp1')
        self.select_channel.on_change('value', self.on_channel_change)

        # Plot - LSL EEG Stream
        self.plot_stream = figure(title='Temporal EEG signal',
                                  x_axis_label='Time [s]',
                                  y_axis_label='Amplitude',
                                  plot_height=500,
                                  plot_width=800,
                                  visible=False)
        self.plot_stream.line(x='ts', y='eeg', source=self.channel_source)

        # Plot - Chronogram prediction vs results
        self.plot_chronogram = figure(title='Chronogram',
                                      x_axis_label='Time [s]',
                                      y_axis_label='Action',
                                      plot_height=300,
                                      plot_width=800)
        self.plot_chronogram.cross(x='ts',
                                   y='y_pred',
                                   color='red',
                                   source=self.chrono_source,
                                   legend_label='Prediction')
        self.plot_chronogram.legend.background_fill_alpha = 0.6
        self.plot_chronogram.yaxis.ticker = list(self.pred_decoding.keys())
        self.plot_chronogram.yaxis.major_label_overrides = self.pred_decoding

        # Div - Display useful information
        self.model_info = Div(text='<b>Model:</b> None')
        self.pred_info = Div(text='<b>Prediction:</b> None')
        self.image = Div()

        # Create layout
        column1 = column(self.button_lsl, self.button_record,
                         self.select_model, self.select_channel,
                         self.div_settings, self.checkbox_settings)
        column2 = column(self.plot_stream, self.plot_chronogram)
        column3 = column(self.model_info, self.pred_info, self.image)
        return row(column1, column2, column3)
示例#16
0
def run(new):
    """Toggle callback: advance the simulation ``slider.value`` iterations.

    Each iteration draws 100 random rates in [0, 100), maps each one to a
    colour from ``colors``, writes both into ``source.data``, updates the
    title of ``p`` and sleeps 5 s. Switching the toggle off does nothing.

    NOTE(review): the whole loop, including ``time.sleep``, runs inside the
    Bokeh callback. It most likely blocks the server, and hides intermediate
    frames, until every iteration has finished. A periodic callback would be
    the usual design.
    """
    global counter

    # on_click passes the toggle's new active state; only run when switched on
    if not new:
        return

    for _ in range(slider.value):
        counter += 1
        rates = np.random.uniform(0, 100, size=100).tolist()
        # int(rate / 16.667) is 0..5 for rates in [0, 100): palette entries 2..7
        color = [colors[2 + int(rate / 16.667)] for rate in rates]

        p.title = 'Algorithms Deployed, Iteration: {}'.format(counter)
        source.data['rate'] = rates
        source.data['color'] = color
        time.sleep(5)

# START toggle: run() receives the toggle's new active state on each click
toggle = Toggle(label='START')
toggle.on_click(run)

# Number of iterations run() performs per activation (slider.value)
slider = Slider(name='N iterations to advance',
                title='N iterations to advance',
                start=5,
                end=10000,
                step=5,
                value=500)

# set up layout
# NOTE(review): HBox/VBox are legacy Bokeh layouts; newer releases use row/column
toggler = HBox(toggle)
inputs = VBox(toggler, slider)

# add to document
curdoc().add_root(HBox(inputs))
示例#17
0
def addHPC(clicked):
    """Toggle callback: add or remove the 'Nuclear with HPC' series.

    Switching on also makes sure the plain 'Nuclear' series is shown.
    Switching off removes only 'Nuclear with HPC'. ``showcols`` is updated in
    place and redrawn with ``showLines``. The global ``HPC`` mirrors the
    toggle state, and the button turns 'success' while active.
    """
    global HPC
    HPC = clicked
    if HPC:
        if 'Nuclear' not in showcols:
            showcols.append('Nuclear')
        # Guard against duplicate entries if the series is already listed
        if 'Nuclear with HPC' not in showcols:
            showcols.append('Nuclear with HPC')
        HPCButton.button_type = 'success'
    else:
        if 'Nuclear with HPC' in showcols:
            showcols.remove('Nuclear with HPC')
        HPCButton.button_type = 'default'
    showLines(showcols)


HPCButton.on_click(addHPC)

# Toggle for including the tidal lagoon series; addLagoon updates its colour
lagoonButton = Toggle(label='Add tidal lagoon', button_type='default')


def addLagoon(clicked):
    global lagoon
    lagoon = clicked
    if clicked:
        lagoonButton.button_type = 'success'
        if 'All other' not in showcols:
            showcols.append('All other')
        showcols.append('All other - including lagoon')
    else:
        lagoonButton.button_type = 'default'
        try:
示例#18
0
button.on_click(callback)

# add a button widget and configure with the call back
button2 = Button(label="Press Me too!")
button2.on_click(callback2)

button3 = Button(label="Display Sliders Demo")
button3.on_click(callback3)

doc = curdoc()

# "Settings" toggle that shows/hides the `inputs` widgets (see callback4)
button4 = Toggle(label="Settings", button_type="success")

# put the button and plot in a layout and add to the document
doc.add_root(row(column(button, p1), column(button2, p2)))
doc.add_root(row(column(button3)))
# NOTE(review): secret_row_1 is not added to `doc` in this snippet; confirm
# that it is added elsewhere.
secret_row_1 = row(column(button4, inputs, plot, width=400))

# Start with the settings widgets (index 1, i.e. `inputs`) removed, keeping a
# reference so callback4 can re-insert them
child = secret_row_1.children[0].children.pop(1)


def callback4(toggled):
    """Show (toggled on) or hide (toggled off) the widget held in ``child``.

    The widget sits at index 1 of the column inside ``secret_row_1``.
    """
    settings_column = secret_row_1.children[0].children
    if toggled:
        settings_column.insert(1, child)
    else:
        settings_column.pop(1)


button4.on_click(callback4)

def reload_plot(attr, old, new):
    """on_change callback: rebuild the plot with fresh renderers (event args unused)."""
    refreshed = add_renderers()
    global_plot.children = [refreshed]


def remove_background(_):
    """Toggle callback: recolour the button by state, then rebuild the plot.

    The button is 'success' while active and 'danger' while inactive.
    """
    is_active = toggle_background_button.active
    toggle_background_button.button_type = 'success' if is_active else 'danger'
    global_plot.children = [add_renderers()]


# Rebuild the plot whenever the colour field or the background toggle changes
color_from_dropdown.on_change('value', reload_plot)
toggle_background_button.on_click(remove_background)

# Initial plot plus empty containers for the selected point's details
global_plot.children = [add_renderers()]
point_info = Div()
point_probabilities = Row(children=[])


def get_attributes_cb(attr, old, new):
    """on_change adapter: recompute the point details, ignoring the event arguments."""
    get_attributes(
        source, df,
        radius_source, radius_slider,
        point_info, point_probabilities,
    )


# Recompute the point details when the radius or the selection changes
radius_slider.on_change('value', get_attributes_cb)

# NOTE(review): in newer Bokeh releases the selection is a Selection model
# (source.selected), usually watched with source.selected.on_change('indices', ...).
# Confirm which Bokeh version this targets.
source.on_change('selected', get_attributes_cb)
示例#20
0
    def __init__(self, sv_rt):
        """Initialize a stream control widget.

        Builds the following controls:

        - the "Connect" toggle;
        - the data-type, conversion, double-pixel and rotation controls;
        - the "Show Only Events" checkbox;
        - a slider for browsing recently buffered images.

        Also schedules ``_update_toggle_view`` every 1000 ms.

        Args:
            sv_rt: object whose ``metadata``, ``image`` and ``aggregated_image``
                are set when a previous image is picked with the slider.
        """
        doc = curdoc()
        self.receiver = doc.receiver
        self.stats = doc.stats
        self.jf_adapter = doc.jf_adapter
        self._sv_rt = sv_rt

        # connect toggle button
        def toggle_callback(_active):
            # Browsing previous images is only possible while disconnected and
            # once at least one image has been buffered
            if _active or not self._prev_image_buffer:
                self.prev_image_slider.disabled = True
            else:
                self.prev_image_slider.disabled = False

            self._update_toggle_view()

        toggle = Toggle(label="Connect", button_type="primary", tags=[True], default_size=145)
        toggle.js_on_change("tags", CustomJS(code=js_backpressure_code))
        toggle.on_click(toggle_callback)
        self.toggle = toggle

        # data type select
        datatype_select = Select(
            title="Data type:", value="Image", options=["Image", "Gains"], default_size=145
        )
        self.datatype_select = datatype_select

        # conversion options
        conv_opts_div = Div(text="Conversion options:", margin=(5, 5, 0, 5))
        conv_opts_cbg = CheckboxGroup(
            labels=["Mask", "Gap pixels", "Geometry"], active=[0, 1, 2], default_size=145
        )
        self.conv_opts_cbg = conv_opts_cbg
        self.conv_opts = column(conv_opts_div, conv_opts_cbg)

        # double pixels handling
        double_pixels_div = Div(text="Double pixels:", margin=(5, 5, 0, 5))
        double_pixels_rg = RadioGroup(labels=DP_LABELS, active=0, default_size=145)
        self.double_pixels_rg = double_pixels_rg
        self.double_pixels = column(double_pixels_div, double_pixels_rg)

        # rotate image select
        rotate_values = ["0", "90", "180", "270"]
        rotate_image = Select(
            title="Rotate image (deg):",
            value=rotate_values[0],
            options=rotate_values,
            default_size=145,
        )
        self.rotate_image = rotate_image

        # show only events
        self.show_only_events_toggle = CheckboxGroup(labels=["Show Only Events"], default_size=145)

        # Previous Image slider
        self._prev_image_buffer = deque(maxlen=60)

        def prev_image_slider_callback(_attr, _old, new):
            # The slider always spans 0-59 while the buffer fills gradually, so
            # ignore positions that do not hold an image yet (IndexError otherwise)
            idx = int(new)
            if idx >= len(self._prev_image_buffer):
                return
            sv_rt.metadata, sv_rt.image = self._prev_image_buffer[idx]
            # TODO: fix this workaround
            sv_rt.aggregated_image = sv_rt.image

        prev_image_slider = Slider(
            start=0, end=59, value_throttled=0, step=1, title="Previous Image", disabled=True,
        )
        prev_image_slider.on_change("value_throttled", prev_image_slider_callback)
        self.prev_image_slider = prev_image_slider

        doc.add_periodic_callback(self._update_toggle_view, 1000)
示例#21
0
def prepare_graph():
    """Build the Manhattan heatmap: a Google Maps figure with office and census
    tract glyphs, their hover tools, and two toggles that show/hide each layer.

    Returns:
        A ``row`` layout holding the map and the toggle buttons, or ``None``
        when any step raises (the error message and traceback are printed).
    """
    import traceback

    def _visibility_handler(renderer):
        """Return an ``on_click`` handler that makes *renderer* follow the toggle state."""
        def _on_click(active):
            renderer.visible = bool(active)
        return _on_click

    try:
        #____________
        # ---- Importing & Creating base figure from Google Maps ----
        #____________
        knotel_off, patch_source = prepare_data()

        # the map is set to Manhattan
        map_options = GMapOptions(lat=40.741,
                                  lng=-73.995,
                                  map_type="roadmap",
                                  zoom=13,
                                  styles=json.dumps(map_other_config))
        #importing GMap into a Bokeh figure
        p = gmap(MH_GMAPS_KEY,
                 map_options,
                 title="Manhattan Heatmap — A Prototype",
                 plot_width=1070,
                 plot_height=800,
                 output_backend="webgl",
                 tools=['pan', 'wheel_zoom', 'reset', 'box_select', 'tap'])

        #____________
        # ---- OFFICES GLYPH ----
        #____________
        initial_office = Diamond(x="long",
                                 y="lat",
                                 size=18,
                                 fill_color="blue",
                                 fill_alpha=0.7,
                                 line_color="black",
                                 line_alpha=0.7)
        selected_office = Diamond(fill_color="blue", fill_alpha=1)
        nonselected_office = Diamond(fill_color="blue",
                                     fill_alpha=0.15,
                                     line_alpha=0.15)
        # glyph gets added to the plot
        office_renderer = p.add_glyph(knotel_off,
                                      initial_office,
                                      selection_glyph=selected_office,
                                      nonselection_glyph=nonselected_office)
        # hover behavior pointing to office glyph
        office_hover = HoverTool(renderers=[office_renderer],
                                 tooltips=[("Revenue", "@revenue{$00,}"),
                                           ("Rentable SQF",
                                            "@rentable_sqf{00,}"),
                                           ("People density", "@ppl_density"),
                                           ("Address", "@formatted_address")])
        p.add_tools(office_hover)

        #____________
        # ---- TRACT GLYPH ----
        #____________
        tract_renderer = p.patches(xs='xs',
                                   ys='ys',
                                   source=patch_source,
                                   fill_alpha=0,
                                   line_color='red',
                                   line_dash='dashed',
                                   hover_color='red',
                                   hover_fill_alpha=0.5)
        # Tracts should look the same whether selected or not, so both the
        # selection and non-selection glyphs reuse the initial style.
        initial_tract = Patches(fill_alpha=0,
                                line_color='red',
                                line_dash='dashed')
        tract_renderer.selection_glyph = initial_tract
        tract_renderer.nonselection_glyph = initial_tract
        # hover behavior pointing to tract glyph
        tract_hover = HoverTool(renderers=[tract_renderer],
                                tooltips=[("Median Household Income",
                                           "@MHI{$00,}")],
                                mode='mouse')
        p.add_tools(tract_hover)

        # Other figure configurations
        p.yaxis.axis_label = "Latitude"
        p.xaxis.axis_label = "Longitude"
        p.toolbar.active_inspect = [tract_hover, office_hover]
        p.toolbar.active_tap = "auto"
        p.toolbar.active_scroll = "auto"

        #____________
        # ---- Adding widgets & Interactions to figure
        #____________
        # Both toggles start active, matching the renderers' initial visibility.
        show_office_toggle = Toggle(label='Show Leased Bldgs.',
                                    active=True,
                                    button_type='primary')
        show_office_toggle.on_click(_visibility_handler(office_renderer))

        show_tract_toggle = Toggle(label='Show Census Tracts',
                                   active=True,
                                   button_type='danger')
        show_tract_toggle.on_click(_visibility_handler(tract_renderer))

        # plotting
        layout = row(p, widgetbox(show_office_toggle, show_tract_toggle))

        return layout

    except Exception as err:
        # Best-effort contract kept (report and return None instead of raising),
        # but print the traceback too so the failing step can be located.
        print("ERROR found on prepare_graph:\n{}".format(err))
        traceback.print_exc()
        return None
示例#22
0
    TableColumn(field="N", title="Nitrate"),
    TableColumn(field="O", title="Oxygen"),
    TableColumn(field="Z", title="Tidal Height")
]
# Tabular view of `source` using the column definitions in `columns` (300x600 px).
data_table = DataTable(source=source, columns=columns,
                       width=300, height=600)


def toggle_callback(attr):
    """Keep the tide toggle's label in sync with its state.

    The toggle's ``active`` flag is read after the press, so the label
    describes what the *next* click will do. The ``attr`` argument passed by
    ``on_click`` is not used; the widget state is read directly.
    """
    tide_toggle.label = "Disable Tides" if tide_toggle.active else "Enable Tides"
# Tide toggle: `callback=` attaches toggle_ocean, and on_click additionally
# registers toggle_callback to keep the label in sync.
# NOTE(review): the `callback=` keyword on widgets (and VBox/HBox below) come
# from old Bokeh releases and were removed later — confirm the pinned version.
tide_toggle = Toggle(label="Enable Tides", callback=toggle_ocean)
tide_toggle.on_click(toggle_callback)

download_button = Button(label="Download data", callback=download_data)

# Runs update_plots on click; a check_fish callback was disabled here.
go_button = Button(label="Run model")#, callback=check_fish)
go_button.on_click(update_plots)


# Set up app layout
prods = VBox(gas_exchange_slider, productivity_slider)
river = VBox(river_flow_slider, river_N_slider)
tide_run = HBox(tide_toggle, download_button, go_button)
all_settings = VBox(prods, river, tide_run,
                    width=400)

# Add to current document
示例#23
0

def show_undef_config(active):
    """Handle clicks on the "Undeformed Configuration" toggle.

    When the toggle becomes active, the "Deformed Configuration" toggle is
    switched off, ``structure.update_system`` is called with ``np.zeros(3)``
    and the force indicator location is refreshed. Deactivation is a no-op.

    Args:
        active: New toggle state as passed by Bokeh's ``on_click``.
    """
    if not active:
        return
    # Keep the two configuration toggles mutually exclusive.
    def_config_button.active = False
    structure.update_system(np.zeros(3))
    structure.update_force_indicator_location()


# Toggles selecting which structure configuration is displayed. Activating the
# undeformed one switches the deformed toggle off (see show_undef_config).
def_config_button = Toggle(label="Deformed Configuration",
                           button_type="success",
                           width=175)
def_config_button.on_click(show_def_config)

undef_config_button = Toggle(label="Undeformed Configuration",
                             button_type="success",
                             width=175)
undef_config_button.on_click(show_undef_config)

##################################### (6) #####################################
# Columns of the seismic-parameters information table: (data field, header).
columns = [
    TableColumn(field=field_name, title=heading)
    for field_name, heading in (
        ("subject", "Subject"),
        ("modeOne", "Mode One"),
        ("modeTwo", "Mode Two"),
        ("modeThree", "Mode Three"),
    )
]
data_table = DataTable(source=siesmicParameters.informationTable,
                       columns=columns,
# Wall-clock reference taken at startup (its use is outside this excerpt).
start_time = time.time()
# NOTE(review): module-level flag and link list; their use is outside this excerpt.
Active = False
linklst = []
args = getArgs()
df, link = createDFS()
#Add the fields, buttons and sliders for the search function
search = TextInput(title="Search GO term")
searchDesc = TextInput(title="Search description")
accuracy_slider = RangeSlider(start=0, end=1, value=(0,1), step=.01, title="Accuracy")
# Importance slider spans the observed min..max of df['Importance'], only when requested.
if args.importance:
    importance_slider = RangeSlider(start=float(df['Importance'].min()), end=float(df['Importance'].max()), value=(float(df['Importance'].min()), float(df['Importance'].max())), step=.01, title="Importance")
button = Button(label="Search", button_type="success")
button.on_click(update)
# NOTE(review): toggleDiff only exists when exactly two models are loaded; any
# later reference must be guarded the same way — confirm.
if args.models == 2:
    toggleDiff = Toggle(label="Show differences", button_type="success")
    toggleDiff.on_click(showDiff)
#Creates color mappers/legends
color_mapper = LinearColorMapper(palette=cc.gray[::-1], low=0, high=1)
color_bar = ColorBar(color_mapper=color_mapper, ticker=BasicTicker(), title='Accuracy',
                     label_standoff=12, location=(0,0))
#Set plot properties
tools = "pan,box_zoom,reset,xwheel_zoom,save"
plot = figure(title="GO visualizer", sizing_mode='stretch_both', tools = tools, output_backend="webgl", width=1400, height=840, toolbar_location="above", active_scroll="xwheel_zoom")
# Main graph view: no grid lines and no axes.
plot.xgrid.grid_line_color = None
plot.ygrid.grid_line_color = None
plot.axis.visible = False
toolStat1 = "pan,box_zoom,wheel_zoom,reset,save"
plotStat1 = figure(title="Information content VS performance", sizing_mode='stretch_both', tools = toolStat1, output_backend="webgl", width=650, height=650, toolbar_location="above")
plotStat1.xaxis.axis_label = "Information content"
plotStat1.yaxis.axis_label = "Performance"
#Convert dataframe to datasource
示例#25
0
# Call modify_test_percentage whenever the dataset slider's value changes.
dataset_slider.on_change('value', modify_test_percentage)


def turn_arrow_labels_off(new):
    """Show or hide the arrow labels on ``p`` and ``best_root_plot``.

    Despite the name this works in both directions: a truthy *new* shows the
    labels, a falsy one hides them. ``arrow_button``'s label is then set to
    describe the next action.
    """
    show = True if new else False
    for target_plot in (p, best_root_plot):
        target_plot.select(name="arrowLabels").visible = show
    arrow_button.label = "Hide Arrow Labels" if show else "Show Arrow Labels"


arrow_button.on_click(turn_arrow_labels_off)


def update_attributes(new):
    """Rebuild ``active_attributes_list`` from the checked checkbox indices.

    The list is updated in place so other holders of the same list object see
    the change. ``apply_changes_button`` is disabled when a root is selected
    (``selected_root`` is non-empty) but is no longer among the active
    attributes.
    """
    del active_attributes_list[:]
    active_attributes_list.extend(Instance.attr_list[idx] for idx in new)
    root_dropped = selected_root != '' and selected_root not in active_attributes_list
    apply_changes_button.disabled = True if root_dropped else False


attribute_checkbox.on_click(update_attributes)
示例#26
0
class Dashboard:
    """Explorepy dashboard class"""

    def __init__(self, explore=None, mode='signal'):
        """Set up dashboard state and the Bokeh data sources the plots read from.

        No Bokeh document exists yet: plots and widgets are created in
        ``_init_doc``, which ``start_server`` registers as the app handler.

        Args:
            explore: Explore object whose ``stream_processor`` supplies device
                info and packet streams. Although it defaults to ``None``, a real
                object is required because ``explore.stream_processor`` is read
                immediately.
            mode (str): ``'signal'`` for the ExG/orientation/FFT view or
                ``'impedance'`` for the impedance-measurement view.
        """
        logger.debug(f"Initializing dashboard in {mode} mode")
        self.explore = explore
        self.stream_processor = self.explore.stream_processor
        self.n_chan = self.stream_processor.device_info['adc_mask'].count(1)
        self.y_unit = DEFAULT_SCALE
        # One vertical offset per channel (1..n_chan) as an (n_chan, 1) column so it
        # broadcasts over the sample axis when traces are stacked on the ExG plot.
        self.offsets = np.arange(1, self.n_chan + 1)[:, np.newaxis].astype(float)
        # Enabled channel names. The mask is reversed so position i pairs with
        # CHAN_LIST[i] (assumes the mask's last entry is channel 1 — confirm).
        self.chan_key_list = [CHAN_LIST[i]
                              for i, mask in enumerate(reversed(self.stream_processor.device_info['adc_mask'])) if
                              mask == 1]
        self.exg_mode = 'EEG'
        self.rr_estimator = None
        self.win_length = WIN_LENGTH
        self.mode = mode
        self.exg_fs = self.stream_processor.device_info['sampling_rate']
        self._vis_time_offset = None
        self._baseline_corrector = {"MA_length": 1.5 * EXG_VIS_SRATE,
                                    "baseline": 0}

        # Init ExG data source: two points per channel, the second one NaN.
        exg_temp = np.zeros((self.n_chan, 2))
        exg_temp[:, 0] = self.offsets[:, 0]
        exg_temp[:, 1] = np.nan
        init_data = dict(zip(self.chan_key_list, exg_temp))
        self._exg_source_orig = ColumnDataSource(data=init_data)
        init_data['t'] = np.array([0., 0.])
        self._exg_source_ds = ColumnDataSource(data=init_data)  # Downsampled ExG data for visualization purposes

        # Init ECG R-peak source
        init_data = dict(zip(['r_peak', 't'], [np.array([None], dtype=np.double), np.array([None], dtype=np.double)]))
        self._r_peak_source = ColumnDataSource(data=init_data)

        # Init marker source
        init_data = dict(zip(['marker', 't'], [np.array([None], dtype=np.double), np.array([None], dtype=np.double)]))
        self._marker_source = ColumnDataSource(data=init_data)

        # Init ORN data source
        init_data = dict(zip(ORN_LIST, np.zeros((9, 1))))
        init_data['t'] = [0.]
        self._orn_source = ColumnDataSource(data=init_data)

        # Init table sources
        self._heart_rate_source = ColumnDataSource(data={'heart_rate': ['NA']})
        self._firmware_source = ColumnDataSource(
            data={'firmware_version': [self.stream_processor.device_info['firmware_version']]}
        )
        self._battery_source = ColumnDataSource(data={'battery': ['NA']})
        self.temperature_source = ColumnDataSource(data={'temperature': ['NA']})
        self.light_source = ColumnDataSource(data={'light': ['NA']})
        self.battery_percent_list = []
        self.server = None

        # Init fft data source
        init_data = dict(zip(self.chan_key_list, np.zeros((self.n_chan, 1))))
        init_data['f'] = np.array([0.])
        self.fft_source = ColumnDataSource(data=init_data)

        # Init impedance measurement source (one entry per enabled channel)
        init_data = {'channel':   self.chan_key_list,
                     'impedance': ['NA'] * self.n_chan,
                     'row':       ['1'] * self.n_chan,
                     'color':     ['black'] * self.n_chan}
        self.imp_source = ColumnDataSource(data=init_data)

        # Init timer source
        self._timer_source = ColumnDataSource(data={'timer': ['00:00:00']})

    def start_server(self):
        """Create and start a Bokeh server that serves the dashboard at ``/``.

        A free port is chosen automatically and ``_init_doc`` builds each
        session's document. The IO loop itself is started by ``start_loop``.
        """
        # NOTE(review): presumably disables Bokeh property validation globally — confirm import.
        validate(False)
        logger.debug("Starting bokeh server...")
        port = find_free_port()
        logger.info("Opening the dashboard on port: %i", port)
        app_routes = {'/': self._init_doc}
        self.server = Server(app_routes, num_procs=1, port=port)
        self.server.start()

    def start_loop(self):
        """Open the dashboard in a browser and run the server's IO loop (blocking).

        On Ctrl+C in signal mode, LSL streaming and recording are stopped and
        the process exits immediately via ``os._exit(0)``. In impedance mode the
        interrupt is propagated to the caller.
        """
        logger.debug("Starting bokeh io_loop...")
        self.server.io_loop.add_callback(self.server.show, "/")
        try:
            self.server.io_loop.start()
        except KeyboardInterrupt:
            if self.mode == 'signal':
                logger.info("Got Keyboard Interrupt. The program exits ...")
                self.explore.stop_lsl()
                self.explore.stop_recording()
                os._exit(0)
            else:
                logger.info("Got Keyboard Interrupt. The program exits after disabling the impedance mode ...")
                # Re-raise the original interrupt (keeping its traceback) so the
                # caller can presumably disable impedance mode before exiting.
                raise

    def exg_callback(self, packet):
        """
        Update ExG data in the visualization

        Full-rate samples are streamed into ``_exg_source_orig`` (read by the FFT
        update). In signal mode, a downsampled, optionally baseline-corrected and
        scaled copy is scheduled for the ExG plot on the next document tick.

        Args:
            packet (explorepy.packet.EEG): Received ExG packet

        """
        time_vector, exg = packet.get_data(self.exg_fs)
        if self._vis_time_offset is None:
            self._vis_time_offset = time_vector[0]
        # Out-of-place arithmetic throughout: if get_data() returns views of the
        # packet's own buffers (not verified here), in-place ops would alter the
        # data other subscribers of the same packet see.
        time_vector = time_vector - self._vis_time_offset
        self._exg_source_orig.stream(dict(zip(self.chan_key_list, exg)), rollover=int(self.exg_fs * self.win_length))

        if self.mode == 'signal':
            # Downsampling towards EXG_VIS_SRATE; a slice step of 0 would raise
            # ValueError, so never go below 1.
            ds_step = max(1, int(self.exg_fs / EXG_VIS_SRATE))
            exg = exg[:, ::ds_step]
            time_vector = time_vector[::ds_step]

            # Baseline correction: the baseline moves towards this packet's mean
            # by a fraction n_samples / MA_length of the difference.
            if self.baseline_widget.active:
                samples_avg = exg.mean(axis=1)
                if self._baseline_corrector["baseline"] is None:
                    self._baseline_corrector["baseline"] = samples_avg
                else:
                    self._baseline_corrector["baseline"] -= (
                            (self._baseline_corrector["baseline"] - samples_avg) / self._baseline_corrector["MA_length"] *
                            exg.shape[1])
                exg = exg - self._baseline_corrector["baseline"][:, np.newaxis]
            else:
                self._baseline_corrector["baseline"] = None

            # Update ExG unit: scale and shift each channel onto its own row
            exg = self.offsets + exg / self.y_unit
            new_data = dict(zip(self.chan_key_list, exg))
            new_data['t'] = time_vector
            self.doc.add_next_tick_callback(partial(self._update_exg, new_data=new_data))

    def orn_callback(self, packet):
        """Queue a new orientation sample for the orientation plots.

        Nothing is done unless the "Orientation" tab (index 1) is showing.
        Timestamps are shifted by the shared ``_vis_time_offset`` so they line up
        with the ExG plot.

        Args:
            packet (explorepy.packet.Orientation): Orientation packet
        """
        orientation_tab_shown = self.tabs.active == 1
        if not orientation_tab_shown:
            return
        timestamp, orn_data = packet.get_data()
        if self._vis_time_offset is None:
            self._vis_time_offset = timestamp[0]
        timestamp -= self._vis_time_offset
        # One single-sample column per orientation channel.
        per_channel = np.array(orn_data)[:, np.newaxis]
        new_data = dict(zip(ORN_LIST, per_channel))
        new_data['t'] = timestamp
        self.doc.add_next_tick_callback(partial(self._update_orn, new_data=new_data))

    def info_callback(self, packet):
        """Schedule updates of the device-information table cells.

        Firmware version and temperature are forwarded unchanged. Battery
        readings are smoothed over the last ``BATTERY_N_MOVING_AVERAGE`` values,
        truncated to a multiple of 5 and floored at 1. Light is converted to an
        int. Unknown keys are logged as warnings.

        Args:
            packet (explorepy.packet.Environment): Environment/DeviceInfo packet

        """
        passthrough = {'firmware_version': self._update_fw_version,
                       'temperature': self._update_temperature}
        for key, values in packet.get_data().items():
            if key in passthrough:
                self.doc.add_next_tick_callback(partial(passthrough[key], new_data={key: values}))
            elif key == 'battery':
                self.battery_percent_list.append(values[0])
                if len(self.battery_percent_list) > BATTERY_N_MOVING_AVERAGE:
                    del self.battery_percent_list[0]
                smoothed = max(int(np.mean(self.battery_percent_list) / 5) * 5, 1)
                self.doc.add_next_tick_callback(partial(self._update_battery, new_data={key: [smoothed]}))
            elif key == 'light':
                self.doc.add_next_tick_callback(partial(self._update_light, new_data={key: [int(values[0])]}))
            else:
                logger.warning("There is no field named: " + key)

    def marker_callback(self, packet):
        """Queue a vertical event-marker line for the ExG plot.

        Ignored in impedance mode.

        Args:
            packet (explorepy.packet.EventMarker): Event marker packet
        """
        if self.mode == "impedance":
            return
        timestamp, _ = packet.get_data()
        if self._vis_time_offset is None:
            self._vis_time_offset = timestamp[0]
        timestamp -= self._vis_time_offset
        marker_time = timestamp[0]
        # Two points spanning all channel rows at the marker time, then a NaN
        # point (None cast to double) so successive markers are not joined.
        new_data = {
            'marker': np.array([0.01, self.n_chan + 0.99, None], dtype=np.double),
            't': np.array([marker_time, marker_time, None], dtype=np.double),
        }
        self.doc.add_next_tick_callback(partial(self._update_marker, new_data=new_data))

    def impedance_callback(self, packet):
        """Colour-code and label each channel's impedance and queue a plot update.

        Bands (strict ``>`` comparisons, values in kΩ): above 500 is "Open"
        (black), above 100 red, above 50 orange, above 10 yellow, above 5 green
        with its value, and anything else green "<5KΩ".

        Args:
             packet (explorepy.packet.EEG): ExG packet

        Raises:
            RuntimeError: If the dashboard is not in impedance mode.
        """
        if self.mode != "impedance":
            raise RuntimeError("Trying to compute impedances while the dashboard is not in Impedance mode!")

        # (lower bound, colour), checked from the highest bound down.
        bands = ((100, "red"), (50, "orange"), (10, "yellow"), (5, "green"))
        colors, labels = [], []
        for value in packet.get_impedances():
            if value > 500:
                colors.append("black")
                labels.append("Open")
                continue
            band_color = next((c for bound, c in bands if value > bound), None)
            if band_color is None:
                colors.append("green")
                labels.append("<5K\u03A9")  # As the ADS is not precise in low values.
            else:
                colors.append(band_color)
                labels.append(str(round(value, 0)) + " K\u03A9")

        data = {"impedance": labels,
                'channel':   self.chan_key_list,
                'row':       ['1'] * self.n_chan,
                'color':     colors}
        self.doc.add_next_tick_callback(partial(self._update_imp, new_data=data))

    @gen.coroutine
    @without_property_validation
    def _update_exg(self, new_data):
        """Append downsampled ExG samples, keeping 2 * EXG_VIS_SRATE * WIN_LENGTH points."""
        max_points = int(2 * EXG_VIS_SRATE * WIN_LENGTH)
        self._exg_source_ds.stream(new_data=new_data, rollover=max_points)

    @gen.coroutine
    @without_property_validation
    def _update_orn(self, new_data):
        """Append orientation samples, keeping 2 * WIN_LENGTH * ORN_SRATE points."""
        max_points = int(2 * WIN_LENGTH * ORN_SRATE)
        self._orn_source.stream(new_data=new_data, rollover=max_points)

    @gen.coroutine
    @without_property_validation
    def _update_fw_version(self, new_data):
        """Replace the firmware-version cell (rollover=1 keeps only the newest row)."""
        self._firmware_source.stream(new_data=new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_battery(self, new_data):
        """Replace the battery cell (rollover=1 keeps only the newest row)."""
        self._battery_source.stream(new_data=new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_temperature(self, new_data):
        """Replace the temperature cell (rollover=1 keeps only the newest row)."""
        self.temperature_source.stream(new_data=new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_light(self, new_data):
        """Replace the light-sensor cell (rollover=1 keeps only the newest row)."""
        self.light_source.stream(new_data=new_data, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _update_marker(self, new_data):
        """Append event-marker points, keeping at most the last 100."""
        self._marker_source.stream(new_data, rollover=100)

    @gen.coroutine
    @without_property_validation
    def _update_imp(self, new_data):
        """Replace the impedance rows; rollover=n_chan keeps one row per channel."""
        self.imp_source.stream(new_data=new_data, rollover=self.n_chan)

    @gen.coroutine
    @without_property_validation
    def _update_fft(self):
        """Recompute the spectrum plot from the full-rate ExG buffer.

        Runs only while the "Spectral analysis" tab (index 2) is showing in EEG
        mode, and only once at least 5 seconds of samples are buffered.
        """
        fft_tab_shown = self.tabs.active == 2
        if not fft_tab_shown or self.exg_mode != 'EEG':
            return

        samples = np.array([self._exg_source_orig.data[chan] for chan in self.chan_key_list])
        min_samples = self.exg_fs * 5
        if samples.shape[1] < min_samples:
            return
        fft_content, freq = get_fft(samples, self.exg_fs)
        spectrum = dict(zip(self.chan_key_list, fft_content))
        spectrum['f'] = freq
        self.fft_source.data = spectrum

    @gen.coroutine
    @without_property_validation
    def _update_heart_rate(self):
        """Detect R-peaks in recent channel-1 data and refresh the heart-rate cell.

        In EEG mode the cell is reset to "NA". Otherwise the last
        2 * EXG_VIS_SRATE downsampled samples of 'Ch1' are mapped back to signal
        units. Detected peaks are plotted and the estimator's heart rate is shown.
        """
        if self.exg_mode == 'EEG':
            self._heart_rate_source.stream({'heart_rate': ['NA']}, rollover=1)
            return
        if CHAN_LIST[0] not in self.chan_key_list:
            logger.warning('Heart rate estimation works only when channel 1 is enabled.')
            return
        if self.rr_estimator is None:
            # Lazily create the estimator and the R-peak marker glyph on first use.
            self.rr_estimator = HeartRateEstimator(fs=self.exg_fs)
            self.exg_plot.circle(x='t', y='r_peak', source=self._r_peak_source,
                                 fill_color="red", size=8)

        window_start = -2 * EXG_VIS_SRATE
        ch1_offset = self.offsets[0]
        # Undo the plot offset and y-scaling to get channel 1 back in signal units.
        ecg_data = (np.array(self._exg_source_ds.data['Ch1'])[window_start:] - ch1_offset) * self.y_unit
        time_vector = np.array(self._exg_source_ds.data['t'])[window_start:]

        # Reject windows whose peak-to-peak amplitude is outside V_TH.
        # NOTE(review): a NaN peak-to-peak (e.g. from the NaN placeholder sample
        # the source starts with) fails both comparisons and is NOT rejected.
        peak_to_peak = np.ptp(ecg_data)
        if peak_to_peak < V_TH[0] or peak_to_peak > V_TH[1]:
            logger.warning("P2P value larger or less than threshold. Cannot compute heart rate!")
            return

        peaks_time, peaks_val = self.rr_estimator.estimate(ecg_data, time_vector)
        peaks_val = (np.array(peaks_val) / self.y_unit) + ch1_offset
        if peaks_time:
            self._r_peak_source.stream({'r_peak': peaks_val, 't': peaks_time}, rollover=50)

        # Update heart rate cell
        self._heart_rate_source.stream({'heart_rate': [self.rr_estimator.heart_rate]}, rollover=1)

    @gen.coroutine
    @without_property_validation
    def _change_scale(self, attr, old, new):
        """Rescale the plotted ExG traces and R-peak markers to a new y-scale.

        ``old`` and ``new`` are SCALE_MENU keys, each mapping to an exponent e
        with unit 10 ** -e. Every plotted value is shifted to zero around its
        channel offset, multiplied by old_unit / new_unit, and shifted back.
        """
        logger.debug(f"ExG scale has been changed from {old} to {new}")
        new_exp, old_exp = SCALE_MENU[new], SCALE_MENU[old]
        previous_unit = 10 ** (-old_exp)
        self.y_unit = 10 ** (-new_exp)
        ratio = previous_unit / self.y_unit

        ds_data = self._exg_source_ds.data
        for chan, value in ds_data.items():
            if chan not in self.chan_key_list:
                continue
            chan_offset = self.offsets[self.chan_key_list.index(chan)]
            ds_data[chan] = (value - chan_offset) * ratio + chan_offset

        r_peaks = np.array(self._r_peak_source.data['r_peak'])
        self._r_peak_source.data['r_peak'] = (r_peaks - self.offsets[0]) * ratio + self.offsets[0]

    @gen.coroutine
    @without_property_validation
    def _change_t_range(self, attr, old, new):
        """Apply the time window selected in the UI (``new`` is a TIME_RANGE_MENU key)."""
        logger.debug(f"Time scale has been changed from {old} to {new}")
        window_length = TIME_RANGE_MENU[new]
        self._set_t_range(window_length)

    @gen.coroutine
    def _change_mode(self, attr, old, new):
        """Store the ExG mode (EEG or ECG) chosen in the signal selector."""
        self.exg_mode = new
        logger.debug(f"ExG mode has been changed to {new}")

    def _init_doc(self, doc):
        """Build the Bokeh document for one dashboard session.

        The method proceeds in these steps:

        - Load the HTML template and theme stored next to this module.
        - Create the plots, controls and tabs for the current mode, and lay them out.
        - Register the periodic FFT and heart-rate refresh callbacks.
        - Subscribe the packet callbacks to the stream processor.

        Args:
            doc (bokeh.document.Document): Document supplied by the Bokeh server.

        Raises:
            ValueError: If ``self.mode`` is neither ``'signal'`` nor ``'impedance'``.
        """
        self.doc = doc
        self.doc.title = "Explore Dashboard"
        module_dir = os.path.dirname(__file__)
        # Explicit encoding so reading the template does not depend on the
        # platform's locale default.
        with open(os.path.join(module_dir, 'templates', 'index.html'), encoding='utf-8') as f:
            index_template = Template(f.read())
        self.doc.template = index_template
        self.doc.theme = Theme(os.path.join(module_dir, 'theme.yaml'))
        self._init_plots()
        m_widgetbox = self._init_controls()

        # Create tabs
        if self.mode == "signal":
            exg_tab = Panel(child=self.exg_plot, title="ExG Signal")
            orn_tab = Panel(child=column([self.acc_plot, self.gyro_plot, self.mag_plot], sizing_mode='scale_width'),
                            title="Orientation")
            fft_tab = Panel(child=self.fft_plot, title="Spectral analysis")
            self.tabs = Tabs(tabs=[exg_tab, orn_tab, fft_tab], width=400, sizing_mode='scale_width')
            self.recorder_widget = self._init_recorder()
            self.push2lsl_widget = self._init_push2lsl()
            self.set_marker_widget = self._init_set_marker()
            self.baseline_widget = CheckboxGroup(labels=['Baseline correction'], active=[0])

        elif self.mode == "impedance":
            imp_tab = Panel(child=self.imp_plot, title="Impedance")
            self.tabs = Tabs(tabs=[imp_tab], width=500, sizing_mode='scale_width')
        else:
            # Previously this fell through to an UnboundLocalError on `layout`.
            raise ValueError(f"Unknown dashboard mode: {self.mode!r}")
        banner = Div(text=""" <a href="https://www.mentalab.com"><img src=
        "https://images.squarespace-cdn.com/content/5428308ae4b0701411ea8aaf/1505653866447-R24N86G5X1HFZCD7KBWS/
        Mentalab%2C+Name+copy.png?format=1500w&content-type=image%2Fpng" alt="Mentalab"  width="225" height="39">""",
                     width=1500, height=50, css_classes=["banner"], align='center', sizing_mode="stretch_width")
        heading = Div(text=""" """, height=2, sizing_mode="stretch_width")
        if self.mode == 'signal':
            layout = column([heading,
                             banner,
                             row(m_widgetbox,
                                 Spacer(width=10, height=300),
                                 self.tabs,
                                 Spacer(width=10, height=300),
                                 column(Spacer(width=170, height=50), self.baseline_widget, self.recorder_widget,
                                        self.set_marker_widget, self.push2lsl_widget),
                                 Spacer(width=50, height=300)),
                             ],
                            sizing_mode="stretch_both")

        elif self.mode == 'impedance':
            layout = column(banner,
                            Spacer(width=600, height=20),
                            row([m_widgetbox, Spacer(width=25, height=500), self.tabs])
                            )
        self.doc.add_root(layout)
        self.doc.add_periodic_callback(self._update_fft, 2000)
        self.doc.add_periodic_callback(self._update_heart_rate, 2000)
        if self.stream_processor:
            self.stream_processor.subscribe(topic=TOPICS.filtered_ExG, callback=self.exg_callback)
            self.stream_processor.subscribe(topic=TOPICS.raw_orn, callback=self.orn_callback)
            self.stream_processor.subscribe(topic=TOPICS.device_info, callback=self.info_callback)
            self.stream_processor.subscribe(topic=TOPICS.marker, callback=self.marker_callback)
            self.stream_processor.subscribe(topic=TOPICS.env, callback=self.info_callback)
            self.stream_processor.subscribe(topic=TOPICS.imp, callback=self.impedance_callback)

    def _init_plots(self):
        """Initialize all plots in the dashboard.

        Creates the ExG, orientation (acc/gyro/mag), FFT and impedance figures.
        It then draws one line per channel and the marker line, sets the initial
        time range, and applies shared axis and legend styling.
        """
        # ExG traces are stacked: channel k sits around y = k (see self.offsets),
        # so the y-range spans just inside 0 .. n_chan + 1.
        self.exg_plot = figure(y_range=(0.01, self.n_chan + 1 - 0.01), y_axis_label='Voltage', x_axis_label='Time (s)',
                               title="ExG signal",
                               plot_height=250, plot_width=500,
                               y_minor_ticks=int(10),
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")

        self.mag_plot = figure(y_axis_label='Mag [mgauss/LSB]', x_axis_label='Time (s)',
                               plot_height=100, plot_width=500,
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")
        # Accelerometer and gyroscope plots hide their x-axis; only the
        # magnetometer plot carries the time-axis label.
        self.acc_plot = figure(y_axis_label='Acc [mg/LSB]',
                               plot_height=75, plot_width=500,
                               tools=[ResetTool()], active_scroll=None, active_drag=None,
                               active_inspect=None, active_tap=None, sizing_mode="scale_width")
        self.acc_plot.xaxis.visible = False
        self.gyro_plot = figure(y_axis_label='Gyro [mdps/LSB]',
                                plot_height=75, plot_width=500,
                                tools=[ResetTool()], active_scroll=None, active_drag=None,
                                active_inspect=None, active_tap=None, sizing_mode="scale_width")
        self.gyro_plot.xaxis.visible = False

        self.fft_plot = figure(y_axis_label='Amplitude (uV)', x_axis_label='Frequency (Hz)', title="FFT",
                               x_range=(0, 70), plot_height=250, plot_width=500, y_axis_type="log",
                               tools=[BoxZoomTool(), ResetTool()], active_scroll=None, active_drag=None,
                               active_tap=None,
                               sizing_mode="scale_width")

        self.imp_plot = self._init_imp_plot()

        # Set yaxis properties: one major tick per channel row
        self.exg_plot.yaxis.ticker = SingleIntervalTicker(interval=1, num_minor_ticks=0)

        # Initial plot line: one ExG trace and one FFT trace per enabled channel
        for i in range(self.n_chan):
            self.exg_plot.line(x='t', y=self.chan_key_list[i], source=self._exg_source_ds,
                               line_width=1.0, alpha=.9, line_color="#42C4F7")
            self.fft_plot.line(x='f', y=self.chan_key_list[i], source=self.fft_source,
                               legend_label=self.chan_key_list[i] + " ",
                               line_width=1.5, alpha=.9, line_color=FFT_COLORS[i])
        self.fft_plot.yaxis.axis_label_text_font_style = 'normal'
        # Dashed event-marker lines fed by marker_callback via _marker_source
        self.exg_plot.line(x='t', y='marker', source=self._marker_source,
                           line_width=1, alpha=.8, line_color='#7AB904', line_dash="4 4")

        # ORN_LIST entries i, i+3 and i+6 go to the acc, gyro and mag plots respectively
        for i in range(3):
            self.acc_plot.line(x='t', y=ORN_LIST[i], source=self._orn_source, legend_label=ORN_LIST[i] + " ",
                               line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)
            self.gyro_plot.line(x='t', y=ORN_LIST[i + 3], source=self._orn_source, legend_label=ORN_LIST[i + 3] + " ",
                                line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)
            self.mag_plot.line(x='t', y=ORN_LIST[i + 6], source=self._orn_source, legend_label=ORN_LIST[i + 6] + " ",
                               line_width=1.5, line_color=LINE_COLORS[i], alpha=.9)

        # Set x_range
        self.plot_list = [self.exg_plot, self.acc_plot, self.gyro_plot, self.mag_plot]
        self._set_t_range(WIN_LENGTH)

        # Set the formatting of yaxis ticks' labels: row k is labelled with the k-th channel name
        self.exg_plot.yaxis.major_label_overrides = dict(zip(range(1, self.n_chan + 1), self.chan_key_list))
        for plot in self.plot_list:
            plot.toolbar.autohide = True
            plot.yaxis.axis_label_text_font_style = 'normal'
            if len(plot.legend) != 0:
                plot.legend.location = "bottom_left"
                plot.legend.orientation = "horizontal"
                plot.legend.padding = 2

    def _init_imp_plot(self):
        """Build the impedance figure: one circle per channel with two text labels.

        Each circle (from ``self.imp_source``) is labelled with its ``impedance``
        value and its ``channel`` name; toolbar, axes, grid and outline are hidden.

        Returns:
            The configured bokeh figure.
        """
        visible_channels = self.chan_key_list[:self.n_chan]
        imp_fig = figure(plot_width=600, plot_height=200, x_range=visible_channels,
                         y_range=["1"], toolbar_location=None, sizing_mode="scale_width")

        imp_fig.circle(x='channel', y="row", size=50, source=self.imp_source, fill_alpha=0.6,
                       color="color", line_color='color', line_width=2)

        label_style = dict(source=self.imp_source, text_align="center", text_color="white",
                           text_baseline="middle", text_font="helvetica", text_font_style="bold")
        label_x = dodge("channel", -0.1, range=imp_fig.x_range)

        # Two stacked labels per circle: impedance value (smaller) and channel name.
        for field, y_offset, font_size in (("impedance", -.35, "10pt"), ("channel", -.25, "12pt")):
            renderer = imp_fig.text(x=label_x, y=dodge('row', y_offset, range=imp_fig.y_range),
                                    text=field, **label_style)
            renderer.glyph.text_font_size = font_size

        # Strip all chrome so only the circles and their labels remain.
        imp_fig.outline_line_color = None
        imp_fig.grid.grid_line_color = None
        imp_fig.axis.axis_line_color = None
        imp_fig.axis.major_tick_line_color = None
        imp_fig.axis.major_label_standoff = 0
        imp_fig.axis.visible = False
        return imp_fig

    def _init_controls(self):
        """Initialize all controls in the dashboard.

        Creates the signal-type, time-window and y-scale selectors plus the
        read-only device info tables, then lays out the subset that applies to
        ``self.mode``.

        Returns:
            A fixed-size ``widgetbox`` holding the controls for the current mode.

        Raises:
            ValueError: if ``self.mode`` is neither ``'signal'`` nor ``'impedance'``.
        """
        # EEG/ECG selector
        self.mode_control = widgets.Select(title="Signal", value='EEG', options=MODE_LIST, width=170, height=50)
        self.mode_control.on_change('value', self._change_mode)

        self.t_range = widgets.Select(title="Time window", value="10 s", options=list(TIME_RANGE_MENU.keys()),
                                      width=170, height=50)
        self.t_range.on_change('value', self._change_t_range)
        self.y_scale = widgets.Select(title="Y-axis Scale", value="1 mV", options=list(SCALE_MENU.keys()),
                                      width=170, height=50)
        self.y_scale.on_change('value', self._change_scale)

        def info_table(source, field, title):
            # One-column, fixed-size, non-sortable table showing a single device value.
            columns = [widgets.TableColumn(field=field, title=title)]
            return widgets.DataTable(source=source, index_position=None, sortable=False, reorderable=False,
                                     columns=columns, width=170, height=50)

        # Create device info tables
        self.heart_rate = info_table(self._heart_rate_source, 'heart_rate', "Heart Rate (bpm)")
        self.firmware = info_table(self._firmware_source, 'firmware_version', "Firmware Version")
        self.battery = info_table(self._battery_source, 'battery', "Battery (%)")
        self.temperature = info_table(self.temperature_source, 'temperature', "Device temperature (C)")
        # NOTE(review): the light table is created but placed in neither layout below.
        self.light = info_table(self.light_source, 'light', "Light (Lux)")

        if self.mode == 'signal':
            widget_list = [Spacer(width=170, height=30), self.mode_control, self.y_scale, self.t_range,
                           self.heart_rate, self.battery, self.temperature, self.firmware]
        elif self.mode == 'impedance':
            widget_list = [Spacer(width=170, height=40), self.battery, self.temperature, self.firmware]
        else:
            # Fail loudly instead of hitting an UnboundLocalError on widget_list.
            raise ValueError(f"Unknown dashboard mode: {self.mode!r}")

        return widgetbox(widget_list, width=175, height=450, sizing_mode='fixed')

    def _init_recorder(self):
        """Create the recording panel.

        The panel holds a file-name input, an EDF/CSV format selector, the
        record toggle (wired to ``_toggle_rec``) and a single-cell elapsed-time table.

        Returns:
            A fixed-size column containing the recorder widgets.
        """
        self.rec_button = Toggle(label=u"\u25CF  Record", button_type="default", active=False, width=170, height=35)
        self.file_name_widget = TextInput(value="test_file", title="File name:", width=170, height=50)
        self.file_type_widget = RadioGroup(labels=["EDF (BDF+)", "CSV"], active=0, width=170, height=50)

        centered = widgets.StringFormatter(text_align='center')
        timer_columns = [widgets.TableColumn(field='timer', title="Record time", formatter=centered)]
        self.timer = widgets.DataTable(source=self._timer_source, columns=timer_columns, header_row=False,
                                       index_position=None, sortable=False, reorderable=False,
                                       width=170, height=50, css_classes=["timer_widget"])

        self.rec_button.on_click(self._toggle_rec)

        panel_items = [Spacer(width=170, height=5), self.file_name_widget, self.file_type_widget,
                       self.rec_button, self.timer]
        return column(panel_items, width=170, height=200, sizing_mode='fixed')

    def _toggle_rec(self, active):
        """Start or stop recording when the record toggle changes state.

        Args:
            active (bool): New state of ``self.rec_button``.

        Starting requires a connected device; otherwise the toggle is reset and
        nothing else changes. Stopping tears down the recorder and the elapsed-time
        timer only if a recording was actually started (tracked by
        ``self.rec_timer_id``), so a reset after a failed start never removes a
        periodic callback that does not exist.
        """
        logger.debug(f"Pressed record button -> {active}")
        if active:
            if not self.explore.is_connected:
                # Revert the toggle. Resetting `active` presumably re-enters this callback
                # with active=False; the stop branch below is safe in that case.
                self.rec_button.active = False
                return
            self.event_code_input.disabled = False
            self.marker_button.disabled = False
            self.explore.record_data(file_name=self.file_name_widget.value,
                                     file_type=['edf', 'csv'][self.file_type_widget.active],
                                     do_overwrite=True)
            self.rec_button.label = u"\u25A0  Stop"
            self.rec_start_time = datetime.now()
            self.rec_timer_id = self.doc.add_periodic_callback(self._timer_callback, 1000)
        else:
            timer_id = getattr(self, 'rec_timer_id', None)
            if timer_id is not None:
                # Only stop a recording that was started; removing an unknown or
                # already-removed periodic callback raises in Bokeh.
                self.explore.stop_recording()
                self.doc.remove_periodic_callback(timer_id)
                self.rec_timer_id = None
            self.rec_button.label = u"\u25CF  Record"
            self.doc.add_next_tick_callback(partial(self._update_rec_timer, new_data={'timer': '00:00:00'}))
            # Marker controls stay usable while LSL streaming is still running.
            if not self.push2lsl_button.active:
                self.event_code_input.disabled = True
                self.marker_button.disabled = True

    def _timer_callback(self):
        t_delta = (datetime.now() - self.rec_start_time).seconds
        timer_text = ':'.join([str(int(t_delta / 3600)).zfill(2), str(int(t_delta / 60) % 60).zfill(2),
                               str(int(t_delta % 60)).zfill(2)])
        data = {'timer': timer_text}
        self.doc.add_next_tick_callback(partial(self._update_rec_timer, new_data=data))

    def _init_push2lsl(self):
        """Create the "Push to LSL" panel: a title and a start/stop toggle.

        The toggle is wired to ``_toggle_push2lsl``.

        Returns:
            A fixed-size column holding the LSL widgets.
        """
        title_div = Div(text="Push to LSL", width=170, height=10)
        self.push2lsl_button = Toggle(label=u"\u25CF  Start", button_type="default", active=False,
                                      width=170, height=35)
        self.push2lsl_button.on_click(self._toggle_push2lsl)
        top_pad = Spacer(width=170, height=30)
        return column([top_pad, title_div, self.push2lsl_button], width=170, height=200, sizing_mode='fixed')

    def _toggle_push2lsl(self, active):
        """Start or stop pushing data to LSL when the push2lsl toggle changes.

        Args:
            active (bool): New state of ``self.push2lsl_button``.
        """
        logger.debug(f"Pressed push2lsl button -> {active}")
        if not active:
            self.explore.stop_lsl()
            self.push2lsl_button.label = u"\u25CF  Start"
            # Leave the marker controls enabled while a recording is still running.
            if self.rec_button.active:
                return
            self.event_code_input.disabled = True
            self.marker_button.disabled = True
            return

        # NOTE(review): marker controls are enabled before the connection check, so they are
        # briefly enabled even when the device is not connected.
        self.event_code_input.disabled = False
        self.marker_button.disabled = False
        if not self.explore.is_connected:
            # No device to stream from: revert the toggle.
            self.push2lsl_button.active = False
            return
        self.explore.push2lsl()
        self.push2lsl_button.label = u"\u25A0  Stop"

    def _init_set_marker(self):
        """Create the event-marker panel: an event-code input and a "Set" button.

        Both widgets start disabled; the record/LSL toggle callbacks enable them.
        Edits to the code are validated by ``_check_marker_value``, and clicks
        call ``_set_marker``.

        Returns:
            A fixed-size column holding the marker widgets.
        """
        self.marker_button = Button(label=u"Set", button_type="default", width=80, height=31, disabled=True)
        self.event_code_input = TextInput(value="8", title="Event code:", width=80, disabled=True)
        self.event_code_input.on_change('value', self._check_marker_value)
        self.marker_button.on_click(self._set_marker)

        top_pad = Spacer(width=170, height=5)
        # Push the button down so it lines up with the input box under its title.
        button_pad = Spacer(width=50, height=19)
        input_row = row([self.event_code_input, column(button_pad, self.marker_button)], height=50, width=170)
        return column([top_pad, input_row], width=170, height=50, sizing_mode='fixed')

    def _set_marker(self):
        code = self.event_code_input.value
        self.stream_processor.set_marker(int(code))

    def _check_marker_value(self, attr, old, new):
        try:
            code = int(self.event_code_input.value)
            if code < 7 or code > 65535:
                raise ValueError('Value must be an integer between 8 and 65535')
        except ValueError:
            self.event_code_input.value = "7<val<65535"

    @gen.coroutine
    @without_property_validation
    def _update_rec_timer(self, new_data):
        """Stream ``new_data`` into the record-timer table, keeping only the newest row.

        Args:
            new_data (dict): Column data for the timer source, e.g. ``{'timer': 'HH:MM:SS'}``.

        Callers schedule this through ``doc.add_next_tick_callback``.
        """
        source = self._timer_source
        source.stream(new_data, rollover=1)

    def _set_t_range(self, t_length):
        """Change time range of ExG and orientation plots"""
        for plot in self.plot_list:
            self.win_length = int(t_length)
            plot.x_range.follow = "end"
            plot.x_range.follow_interval = t_length
            plot.x_range.range_padding = 0.
            plot.x_range.min_interval = t_length
示例#27
0
        valor_fijo = False


# Radio group with two options: 'Valor fijo' (fixed value, selected initially)
# and 'Sinusoide' (sinusoid). callback_voltajes runs whenever the 'active'
# index changes.
LABELS_V = ['Valor fijo', 'Sinusoide']
voltajes_rg = RadioGroup(labels=LABELS_V, active=0)
voltajes_rg.on_change('active', callback_voltajes)


# START BUTTON (switches between manual and automatic mode)
def activador(active):
    """Toggle callback: invert the global ``manual`` flag on every click.

    ``active`` (the toggle's new state) is ignored. The flag is flipped
    instead, so ``manual`` only tracks the toggle if it starts in sync with it.
    """
    global manual
    manual = False if manual else True


# Manual/automatic mode toggle; every click calls activador.
act = Toggle(label="Manual/Automático", button_type="success")
act.on_click(activador)


# SAVE BUTTON
def guardador(active):
    """Toggle callback: invert the global ``guardar`` (save) flag on every click.

    When saving switches on, the global ``T_init`` is set to ``time.time() - t``
    (presumably so elapsed time continues from the current ``t``). ``active``
    is ignored.
    """
    global guardar, T_init
    guardar = not guardar
    if not guardar:
        return
    T_init = time.time() - t


# Save toggle ("Guardar" = "Save"); every click calls guardador.
boton_guardar = Toggle(label="Guardar", button_type="success")
boton_guardar.on_click(guardador)


# TEXT INPUT
示例#28
0
def _create_stream_plot(title, x_axis_label, y_axis_label):
    """Create an empty auto-ranging Plot with pan/zoom/reset tools, two labelled axes and grid lines.

    This is the scaffolding shared by the four plots of the stream tab; callers add their glyphs.
    """
    plot = Plot(
        title=Title(text=title),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location='right',
    )

    # ---- tools
    plot.toolbar.logo = None
    plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

    # ---- axes
    plot.add_layout(LinearAxis(axis_label=x_axis_label), place='below')
    plot.add_layout(LinearAxis(axis_label=y_axis_label, major_label_orientation='vertical'),
                    place='left')

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    return plot


def create(palm):
    """Build the "Stream" tab.

    The tab shows live eTOF waveforms, their cross-correlation, and rolling
    pulse-delay / pulse-length traces. Once per second the module-level
    ``receiver`` is polled; the newest buffered message is analysed with
    ``palm.process`` and pushed into the plots.

    Args:
        palm: Analysis object exposing ``process(waveforms, debug=True)`` and
            ``energy_range`` (the only members used here).

    Returns:
        A bokeh ``Panel`` titled "Stream".
    """
    connected = False
    current_message = None
    stream_t = 0

    doc = curdoc()

    # Streaked and reference waveforms plot
    waveform_plot = _create_stream_plot("eTOF waveforms", 'Photon energy, eV', 'Intensity')

    # ---- line glyphs
    waveform_source = ColumnDataSource(dict(x_str=[], y_str=[], x_ref=[], y_ref=[]))
    waveform_ref_line = waveform_plot.add_glyph(
        waveform_source, Line(x='x_ref', y='y_ref', line_color='blue'))
    waveform_str_line = waveform_plot.add_glyph(
        waveform_source, Line(x='x_str', y='y_str', line_color='red'))

    # ---- legend (clicking an entry hides its line)
    waveform_plot.add_layout(
        Legend(items=[("reference", [waveform_ref_line]), ("streaked", [waveform_str_line])]))
    waveform_plot.legend.click_policy = "hide"

    # Cross-correlation plot, with a vertical span at the detected delay
    xcorr_plot = _create_stream_plot("Waveforms cross-correlation", 'Energy shift, eV', 'Cross-correlation')
    xcorr_source = ColumnDataSource(dict(lags=[], xcorr=[]))
    xcorr_plot.add_glyph(xcorr_source, Line(x='lags', y='xcorr', line_color='purple'))
    xcorr_center_span = Span(location=0, dimension='height')
    xcorr_plot.add_layout(xcorr_center_span)

    # Delays plot
    pulse_delay_plot = _create_stream_plot("Pulse delays", 'Pulse number', 'Pulse delay (uncalib), eV')
    pulse_delay_source = ColumnDataSource(dict(x=[], y=[]))
    pulse_delay_plot.add_glyph(pulse_delay_source, Line(x='x', y='y', line_color='steelblue'))

    # Pulse lengths plot
    pulse_length_plot = _create_stream_plot("Pulse lengths", 'Pulse number', 'Pulse length (uncalib), eV')
    pulse_length_source = ColumnDataSource(dict(x=[], y=[]))
    pulse_length_plot.add_glyph(pulse_length_source, Line(x='x', y='y', line_color='steelblue'))

    # Image buffer slider: re-display the buffered message at the selected index
    def buffer_slider_callback(_attr, _old, new):
        message = receiver.data_buffer[new]
        doc.add_next_tick_callback(partial(update, message=message))

    buffer_slider = Slider(
        start=0,
        end=59,
        value=0,
        step=1,
        title="Buffered Image",
        callback_policy='throttle',
        callback_throttle=500,
    )
    buffer_slider.on_change('value', buffer_slider_callback)

    # Connect toggle button
    def connect_toggle_callback(state):
        nonlocal connected
        if state:
            connected = True
            connect_toggle.label = 'Connecting'
            connect_toggle.button_type = 'default'

        else:
            connected = False
            connect_toggle.label = 'Connect'
            connect_toggle.button_type = 'default'

    connect_toggle = Toggle(label="Connect", button_type='default', width=250)
    connect_toggle.on_click(connect_toggle_callback)

    # Intensity stream reset button
    def reset_button_callback():
        nonlocal stream_t
        stream_t = 1  # keep the latest point in order to prevent full axis reset

    reset_button = Button(label="Reset", button_type='default', width=250)
    reset_button.on_click(reset_button_callback)

    # Stream update coroutine
    async def update(message):
        nonlocal stream_t
        # current_message stays None until the receiver has buffered something;
        # there is nothing to draw in that case.
        if message is None or not connected or receiver.state != 'receiving':
            return

        y_ref = message[receiver.reference].value[np.newaxis, :]
        y_str = message[receiver.streaked].value[np.newaxis, :]

        delay, length, debug_data = palm.process({'0': y_ref, '1': y_str}, debug=True)
        prep_data, lags, corr_res_uncut, _ = debug_data

        waveform_source.data.update(
            x_str=palm.energy_range,
            y_str=prep_data['1'][0, :],
            x_ref=palm.energy_range,
            y_ref=prep_data['0'][0, :],
        )

        xcorr_source.data.update(lags=lags, xcorr=corr_res_uncut[0, :])
        xcorr_center_span.location = delay[0]

        # Rolling traces keep the last 120 points.
        pulse_delay_source.stream({'x': [stream_t], 'y': [delay]}, rollover=120)
        pulse_length_source.stream({'x': [stream_t], 'y': [length]}, rollover=120)

        stream_t += 1

    # Periodic callback to fetch data from receiver
    async def internal_periodic_callback():
        nonlocal current_message
        if waveform_plot.inner_width is None:
            # wait for the initialization to finish, thus skip this periodic callback
            return

        if connected:
            if receiver.state == 'polling':
                connect_toggle.label = 'Polling'
                connect_toggle.button_type = 'warning'

            elif receiver.state == 'receiving':
                connect_toggle.label = 'Receiving'
                connect_toggle.button_type = 'success'

                # Set slider to the right-most position
                if len(receiver.data_buffer) > 1:
                    buffer_slider.end = len(receiver.data_buffer) - 1
                    buffer_slider.value = len(receiver.data_buffer) - 1

                if receiver.data_buffer:
                    current_message = receiver.data_buffer[-1]

        doc.add_next_tick_callback(partial(update, message=current_message))

    doc.add_periodic_callback(internal_periodic_callback, 1000)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, xcorr_plot),
            Spacer(width=30),
            column(buffer_slider, row(connect_toggle, reset_button)),
        ),
        row(pulse_delay_plot, Spacer(width=10), pulse_length_plot),
    )

    return Panel(child=tab_layout, title="Stream")
class TrimerFigure:
    """Interactive Bokeh viewer for trimer configurations stored in gsd files.

    Files matching ``dump*.gsd`` under ``directory`` are grouped by
    ``parse_directory`` into pressure / temperature / crystal / iteration-index
    selections. The current frame of the selected trajectory is drawn with
    ``plot_circles`` and coloured using the selected classification algorithm.
    Media controls step or play through the frames.
    """

    # Classification algorithms offered in the UI, keyed by button label.
    # Each instance works on its own copy (made in __init__), so models added by
    # one instance never leak into this shared class-level dict.
    order_functions: Dict[str, Any] = {
        "None": None,
        "Orient": create_orient_ordering(threshold=0.75),
        "Num Neighs": create_neigh_ordering(neighbours=6),
    }
    controls_width = 400

    _frame = None
    plot = None
    _temperatures = None
    _pressures = None
    _crystals = None
    _iter_index = None

    _callback = None

    def __init__(self, doc, directory: Path = None, models=None) -> None:
        """Set up data sources and widgets; call :meth:`create_doc` to render them.

        Args:
            doc: Bokeh document the viewer is attached to.
            directory: Directory searched for ``dump*.gsd`` files; defaults to the cwd.
            models: Optional list/tuple of model paths. Each one becomes an extra
                classification option named after the file stem.

        Raises:
            ValueError: if ``models`` is given but is not a list or tuple.
        """
        self._doc = doc
        # Placeholder until a gsd trajectory is opened.
        self._trajectory = [None]

        if directory is None:
            directory = Path.cwd()
        self._source = ColumnDataSource({
            "x": [],
            "y": [],
            "orientation": [],
            "colour": [],
            "radius": []
        })

        # Per-instance copy: adding models below must not mutate the class attribute.
        self.order_functions = dict(self.order_functions)
        if models is not None:
            if not isinstance(models, (list, tuple)):
                raise ValueError(
                    "The argument models has to have type list or tuple")

            logger.debug("Found additional models: %s", models)
            for model in models:
                model = Path(model)
                self.order_functions[model.stem] = create_ml_ordering(model)

        self.directory = directory
        self.initialise_directory()
        self._filename_div = Div(text="", width=self.controls_width)

        self.initialise_trajectory_interface()
        self.update_current_trajectory(None, None, None)
        self._playing = False

        # Initialise user interface
        self.initialise_media_interface()

        self.initialise_doc()

    def initialise_directory(self) -> None:
        """Scan ``self.directory`` and build the cascading selection widgets.

        Each selector's options depend on the selections above it: pressure, then
        temperature, then crystal, then iteration index.
        """
        self.variable_selection = parse_directory(self.directory,
                                                  glob="dump*.gsd")
        logger.debug("Pressures present: %s", self.variable_selection.keys())

        self._pressures = sorted(list(self.variable_selection.keys()))
        self._pressure_button = RadioButtonGroup(
            name="Pressure ",
            labels=self._pressures,
            active=0,
            width=self.controls_width,
        )
        self._pressure_button.on_change("active",
                                        self.update_temperature_button)
        pressure = self._pressures[self._pressure_button.active]

        self._temperatures = sorted(
            list(self.variable_selection[pressure].keys()))
        self._temperature_button = Select(
            name="Temperature",
            options=self._temperatures,
            value=self._temperatures[0],
            width=self.controls_width,
        )
        self._temperature_button.on_change("value", self.update_crystal_button)

        temperature = self._temperature_button.value
        self._crystals = sorted(
            list(self.variable_selection[pressure][temperature].keys()))
        self._crystal_button = RadioButtonGroup(name="Crystal",
                                                labels=self._crystals,
                                                active=0,
                                                width=self.controls_width)
        self._crystal_button.on_change("active", self.update_index_button)

        crystal = self._crystals[self._crystal_button.active]
        self._iter_index = sorted(
            list(self.variable_selection[pressure][temperature]
                 [crystal].keys()))
        self._iter_index_button = Select(
            name="Iteration Index",
            options=self._iter_index,
            value=self._iter_index[0],
            width=self.controls_width,
        )
        self._iter_index_button.on_change("value",
                                          self.update_current_trajectory)

    @property
    def pressure(self) -> Optional[str]:
        if self._pressures is None:
            return None
        return self._pressures[self._pressure_button.active]

    @property
    def temperature(self) -> Optional[str]:
        if self._temperatures is None:
            return None
        return self._temperature_button.value

    @property
    def crystal(self) -> Optional[str]:
        logger.debug("Current crystal %s from %s", self._crystal_button.active,
                     self._crystals)
        if self._crystals is None:
            return None
        return self._crystals[self._crystal_button.active]

    @property
    def iter_index(self) -> Optional[str]:
        logger.debug("Current index %s from %s", self._iter_index_button.value,
                     self._iter_index)
        return self._iter_index_button.value

    def update_temperature_button(self, attr, old, new):
        """Refresh temperature options after a pressure change, then cascade down."""
        self._temperatures = sorted(
            list(self.variable_selection[self.pressure].keys()))

        self._temperature_button.options = self._temperatures
        self._temperature_button.value = self._temperatures[0]
        self.update_crystal_button(None, None, None)

    def update_crystal_button(self, attr, old, new):
        """Refresh crystal options after a temperature change, then cascade down."""
        self._crystals = sorted(
            list(self.variable_selection[self.pressure][
                self.temperature].keys()))

        self._crystal_button.labels = self._crystals
        self._crystal_button.active = 0
        self.update_index_button(None, None, None)

    def update_index_button(self, attr, old, new):
        """Refresh iteration-index options after a crystal change, then reload the trajectory."""
        self._iter_index = sorted(
            list(self.variable_selection[self.pressure][self.temperature][
                self.crystal].keys()))
        self._iter_index_button.options = self._iter_index
        self._iter_index_button.value = self._iter_index[0]
        self.update_current_trajectory(None, None, None)

    def create_files_interface(self):
        """Build the file-selection panel.

        The panel shows the current directory and file, and the pressure,
        temperature, crystal and iteration-index selectors.

        Returns:
            A bokeh ``column`` of the selection widgets.
        """
        directory_name = Div(
            text=f"<b>Current Directory:</b><br/>{self.directory}",
            width=self.controls_width,
        )
        self._filename_div = Div(text="", width=self.controls_width)
        current_file = self.get_selected_file()
        if current_file is not None:
            self._filename_div.text = f"<b>Current File:</b><br/>{current_file.name}"
        # Always build the layout; previously a missing file caused an UnboundLocalError.
        return column(
            directory_name,
            self._filename_div,
            Div(text="<b>Pressure:</b>"),
            self._pressure_button,
            Div(text="<b>Temperature:</b>"),
            self._temperature_button,
            Div(text="<b>Crystal Structure:</b>"),
            self._crystal_button,
            Div(text="<b>Iteration Index:</b>"),
            self._iter_index_button,
        )

    def get_selected_file(self) -> Optional[Path]:
        """Return the path selected by the four selectors, or None if nothing is selectable."""
        if self.pressure is None:
            return None
        if self.temperature is None:
            return None
        return self.variable_selection[self.pressure][self.temperature][
            self.crystal][self.iter_index]

    def update_frame(self, attr, old, new) -> None:
        """Load the frame at the slider index and redraw (Bokeh ``on_change`` signature)."""
        self._frame = HoomdFrame(self._trajectory[self.index])
        self.update_data(None, None, None)

    def radio_update_frame(self, attr) -> None:
        """``on_click`` handler of the classification selector; the argument is only forwarded."""
        self.update_frame(attr, None, None)

    @property
    def index(self) -> int:
        """Current trajectory index; 0 before the slider exists."""
        try:
            return self._trajectory_slider.value
        except AttributeError:
            return 0

    def initialise_trajectory_interface(self) -> None:
        """Create the classification-algorithm selector from ``self.order_functions``."""
        logger.debug("Loading Models: %s", self.order_functions.keys())
        self._order_parameter = RadioButtonGroup(
            name="Classification algorithm:",
            labels=list(self.order_functions.keys()),
            active=0,
            width=self.controls_width,
        )
        self._order_parameter.on_click(self.radio_update_frame)

    def create_trajectory_interface(self):
        """Return the layout holding the classification-algorithm selector."""
        return column(
            Div(text="<b>Classification Algorithm:</b>"),
            self._order_parameter,
            Div(text="<hr/>", width=self.controls_width, height=10),
            height=120,
        )

    def update_current_trajectory(self, attr, old, new) -> None:
        """Open the selected gsd file and display its current frame.

        Closes the previously opened trajectory, clamps the slider to the new
        trajectory's length and updates the "Current File" label. This uses the
        Bokeh ``on_change`` signature; the arguments are only forwarded.
        """
        current_file = self.get_selected_file()
        if current_file is None:
            return

        logger.debug("Opening %s", current_file)
        # The initial placeholder list has no close(); real trajectories should be released.
        close_previous = getattr(self._trajectory, "close", None)
        self._trajectory = gsd.hoomd.open(str(current_file), "rb")
        if close_previous is not None:
            close_previous()
        num_frames = len(self._trajectory)

        try:
            # Valid frame indices are 0 .. num_frames - 1.
            if self._trajectory_slider.value >= num_frames:
                self._trajectory_slider.value = num_frames - 1
            self._trajectory_slider.end = num_frames - 1
        except AttributeError:
            # The slider does not exist yet on the first call from __init__.
            pass

        self.update_frame(attr, old, new)
        self._filename_div.text = f"<b>Current File:</b><br/>{current_file.name}"

    def initialise_media_interface(self) -> None:
        """Create the trajectory slider, play/pause and step buttons, and the step-size slider."""
        self._trajectory_slider = Slider(
            title="Trajectory Index",
            value=0,
            start=0,
            # Last valid index is len - 1; keep end > start for Bokeh's slider validation.
            end=max(len(self._trajectory) - 1, 1),
            step=1,
            width=self.controls_width,
        )
        self._trajectory_slider.on_change("value", self.update_frame)

        self._play_pause = Toggle(name="Play/Pause",
                                  label="Play/Pause",
                                  width=int(self.controls_width / 3))
        self._play_pause.on_click(self._play_pause_toggle)
        self._nextFrame = Button(label="Next",
                                 width=int(self.controls_width / 3))
        self._nextFrame.on_click(self._incr_index)
        self._prevFrame = Button(label="Previous",
                                 width=int(self.controls_width / 3))
        self._prevFrame.on_click(self._decr_index)
        self._increment_size = Slider(
            title="Increment Size",
            value=10,
            start=1,
            end=100,
            step=1,
            width=self.controls_width,
        )

    def _incr_index(self) -> None:
        """Advance the slider by the increment size, stopping at its end."""
        if self._trajectory_slider.value < self._trajectory_slider.end:
            self._trajectory_slider.value = min(
                self._trajectory_slider.value + self._increment_size.value,
                self._trajectory_slider.end,
            )

    def _decr_index(self) -> None:
        """Move the slider back by the increment size, stopping at its start."""
        if self._trajectory_slider.value > self._trajectory_slider.start:
            self._trajectory_slider.value = max(
                self._trajectory_slider.value - self._increment_size.value,
                self._trajectory_slider.start,
            )

    def create_media_interface(self):
        """Return the layout holding the frame slider and playback controls."""
        return column(
            Div(text="<b>Media Controls:</b>"),
            self._trajectory_slider,
            row(
                [self._prevFrame, self._play_pause, self._nextFrame],
                width=int(self.controls_width),
            ),
            self._increment_size,
        )

    def _update_source(self, data):
        logger.debug("Data Keys: %s", data.keys())
        self._source.data = data

    def get_order_function(self) -> Optional[Callable]:
        """Return the classification function for the selected radio button (may be None)."""
        return self.order_functions[list(
            self.order_functions.keys())[self._order_parameter.active]]

    def update_data(self, attr, old, new):
        """Recompute the plotted data from the current frame and update the title."""
        if self.plot and self._frame is not None:
            self.plot.title.text = f"Timestep {self._frame.timestep:,}"
        if self._frame is not None:
            data = frame2data(self._frame,
                              order_function=self.get_order_function(),
                              molecule=Trimer())
            self._update_source(data)

    def update_data_attr(self, attr):
        self.update_data(attr, None, None)

    def _play_pause_toggle(self, attr):
        """Start or stop a 100 ms periodic callback that advances the frame."""
        if self._playing:
            self._doc.remove_periodic_callback(self._callback)
            self._playing = False
        else:
            self._callback = self._doc.add_periodic_callback(
                self._incr_index, 100)
            self._playing = True

    @staticmethod
    def create_legend():
        """Build a small figure holding the orientation and classification colour bars."""
        cm_orient = LinearColorMapper(palette=DARK_COLOURS,
                                      low=-np.pi,
                                      high=np.pi)
        cm_class = LinearColorMapper(
            palette=[hpluv_to_hex((0, 0, 60)),
                     hpluv_to_hex((0, 0, 80))],
            low=0,
            high=2)

        plot = figure(width=200, height=250)
        plot.toolbar_location = None
        plot.border_fill_color = "#FFFFFF"
        plot.outline_line_alpha = 0
        cb_orient = ColorBar(
            title="Orientation",
            major_label_text_font_size="10pt",
            title_text_font_style="bold",
            color_mapper=cm_orient,
            orientation="horizontal",
            ticker=FixedTicker(ticks=[-np.pi, 0, np.pi]),
            major_label_overrides={
                -np.pi: "-π",
                0: "0",
                np.pi: "π"
            },
            width=100,
            major_tick_line_color=None,
            location=(0, 120),
        )
        cb_class = ColorBar(
            color_mapper=cm_class,
            title="Classification",
            major_label_text_font_size="10pt",
            title_text_font_style="bold",
            orientation="vertical",
            ticker=FixedTicker(ticks=[0.5, 1.5]),
            major_label_overrides={
                0.5: "Crystal",
                1.5: "Liquid"
            },
            label_standoff=15,
            major_tick_line_color=None,
            width=20,
            height=80,
            location=(0, 0),
        )
        plot.add_layout(cb_orient)
        plot.add_layout(cb_class)
        return plot

    def initialise_doc(self):
        """Create the main configuration figure with a fixed [-30, 30] view."""
        # Note from the original author: when using webgl as the backend the
        # save option doesn't work.
        self.plot = figure(
            width=920,
            height=800,
            aspect_scale=1,
            match_aspect=True,
            title=f"Timestep {0:.5g}",
            output_backend="webgl",
            active_scroll="wheel_zoom",
        )
        self.plot.xgrid.grid_line_color = None
        self.plot.ygrid.grid_line_color = None
        self.plot.x_range.start = -30
        self.plot.x_range.end = 30
        self.plot.y_range.start = -30
        self.plot.y_range.end = 30
        plot_circles(self.plot, self._source)

    def create_doc(self):
        """Assemble the controls, plot and legend and add them to the document."""
        self.update_data(None, None, None)
        controls = column(
            [
                self.create_files_interface(),
                self.create_trajectory_interface(),
                self.create_media_interface(),
            ],
            width=int(self.controls_width * 1.1),
        )
        self._doc.add_root(row(controls, self.plot, self.create_legend()))
        self._doc.title = "Configurations"
示例#30
0
    global p, patches, colors, counter

    for _ in range(slider.value):
        counter += 1
        data = patches.data_source.data.copy()
        rates = np.random.uniform(0, 100, size=100).tolist()
        color = [colors[2 + int(rate / 16.667)] for rate in rates]

        p.title = 'Algorithms Deployed, Iteration: {}'.format(counter)
        source.data['rate'] = rates
        source.data['color'] = color
        time.sleep(5)


# START toggle: each click calls `run`, whose body loops `slider.value` times.
# NOTE(review): `run` calls time.sleep(5) on every iteration, which presumably
# blocks the Bokeh server while it runs. Confirm this is intended.
toggle = Toggle(label='START')
toggle.on_click(run)

# Number of iterations `run` performs per click (5..10000 in steps of 5, default 500).
slider = Slider(name='N iterations to advance',
                title='N iterations to advance',
                start=5,
                end=10000,
                step=5,
                value=500)

# set up layout
# NOTE(review): HBox/VBox are legacy Bokeh layouts (row/column in current
# Bokeh); confirm the pinned Bokeh version still provides them.
toggler = HBox(toggle)
inputs = VBox(toggler, slider)

# add to document
curdoc().add_root(HBox(inputs))
示例#31
0
def create():
    """Build the "Data Viewer" tab.

    The tab lets the user:

    * open a ``.cami`` file, whose ``"filelist"`` entry populates a selector
      of detector data files;
    * browse the frames of the selected file, with horizontal and vertical
      projections of the current frame;
    * view projections of the whole frame stack on the X and Y axes, against
      either frame number or omega;
    * draw a box ROI whose summed intensity is plotted per frame;
    * calculate hkl for the current frame;
    * record selections of the current overview-plot ranges.

    Returns:
        Panel: a Bokeh ``Panel`` titled "Data Viewer" holding the layout.
    """
    # Data of the currently selected detector file. It stays empty until
    # ``filelist_callback`` loads a file, so callbacks must check for that.
    det_data = {}
    # Recorded selections: filename id -> list of
    # [x_start, x_end, y_start, y_end, frame_start, frame_end].
    roi_selection = {}

    upload_div = Div(text="Open .cami file:")

    def upload_button_callback(_attr, _old, new):
        # ``new`` is the base64-encoded content of the uploaded file.
        with io.StringIO(base64.b64decode(new).decode()) as file:
            h5meta_list = pyzebra.parse_h5meta(file)
            file_list = h5meta_list["filelist"]
            filelist.options = file_list
            # An empty file list would otherwise raise IndexError here.
            if file_list:
                filelist.value = file_list[0]

    upload_button = FileInput(accept=".cami")
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        """Show frame ``index`` (default: the spinner value) of the loaded data."""
        # Callbacks such as the auto-range toggle can fire before any file
        # has been loaded; there is nothing to display in that case.
        if not det_data:
            return

        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        # hkl values are computed per frame (see hkl_button_callback), so
        # reset them whenever a different frame is shown.
        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if auto_toggle.active:
            im_max = int(np.max(current_image))
            im_min = int(np.min(current_image))

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

    def update_overview_plot():
        """Redraw both stack projections and set their frame/omega axis."""
        if not det_data:
            return

        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape  # (frames, y, x)
        overview_x = np.mean(h5_data, axis=1)  # averaged over y -> (frames, x)
        overview_y = np.mean(h5_data, axis=2)  # averaged over x -> (frames, y)

        overview_plot_x_image_source.data.update(image=[overview_x], dw=[n_x])
        overview_plot_y_image_source.data.update(image=[overview_y], dw=[n_y])

        if frame_button_group.active == 0:  # Frame
            overview_plot_x.axis[1].axis_label = "Frame"
            overview_plot_y.axis[1].axis_label = "Frame"

            overview_plot_x_image_source.data.update(y=[0], dh=[n_im])
            overview_plot_y_image_source.data.update(y=[0], dh=[n_im])

        elif frame_button_group.active == 1:  # Omega
            overview_plot_x.axis[1].axis_label = "Omega"
            overview_plot_y.axis[1].axis_label = "Omega"

            om = det_data["rot_angle"]
            om_start = om[0]
            # Image height (dh) in omega units: the first-to-last span,
            # extended by one step so that every frame gets a full cell.
            # NOTE(review): divides by zero for a single-frame file; confirm
            # whether such files can occur.
            om_span = (om[-1] - om[0]) * n_im / (n_im - 1)
            overview_plot_x_image_source.data.update(y=[om_start], dh=[om_span])
            overview_plot_y_image_source.data.update(y=[om_start], dh=[om_span])

    def filelist_callback(_attr, _old, new):
        nonlocal det_data
        det_data = pyzebra.read_detector_data(new)

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        update_image(0)
        update_overview_plot()

    filelist = Select()
    filelist.on_change("value", filelist_callback)

    def index_spinner_callback(_attr, _old, new):
        update_image(new)

    index_spinner = Spinner(title="Image index:", value=0, low=0)
    index_spinner.on_change("value", index_spinner_callback)

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    # Invisible (global_alpha=0) glyphs carrying h/k/l so the hover tool can
    # display them.
    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=200,
        plot_width=IMAGE_W * 3,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_H * 3,
        plot_width=200,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        # Sum the intensity inside the drawn box for every frame. Without a
        # box, or before any data is loaded, the ROI plot is cleared.
        if new["x"] and det_data:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            left = int(np.floor(new["x"][0]))
            right = int(np.ceil(new["x"][0] + new["width"][0]))
            bottom = int(np.floor(new["y"][0]))
            top = int(np.ceil(new["y"][0] + new["height"][0]))
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame range
    frame_range = DataRange1d()
    det_x_range = DataRange1d()
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = DataRange1d()
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    def frame_button_group_callback(_active):
        update_overview_plot()

    frame_button_group = RadioButtonGroup(labels=["Frames", "Omega"], active=0)
    frame_button_group.on_click(frame_button_group_callback)

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        palette = cmap_dict[new]
        # Carry the current low/high limits over to the new mapper, so that a
        # manually set display range is not silently reset when only the
        # palette changes.
        for glyph in (image_glyph, overview_plot_x_image_glyph,
                      overview_plot_y_image_glyph):
            old_mapper = glyph.color_mapper
            glyph.color_mapper = LinearColorMapper(
                palette=palette,
                low=getattr(old_mapper, "low", None),
                high=getattr(old_mapper, "high", None),
            )

    colormap = Select(title="Colormap:", options=list(cmap_dict.keys()))
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    radio_button_group = RadioButtonGroup(labels=["nb", "nb_bi"], active=0)

    # Minimal gap kept between the display min and max spinner values.
    STEP = 1

    # ---- colormap auto toggle button
    def auto_toggle_callback(state):
        # In auto mode the display range follows the data (see update_image),
        # so the manual range spinners are disabled.
        display_min_spinner.disabled = bool(state)
        display_max_spinner.disabled = bool(state)

        update_image()

    auto_toggle = Toggle(label="Auto Range",
                         active=True,
                         button_type="default")
    auto_toggle.on_click(auto_toggle_callback)

    # ---- colormap display max value
    def display_max_spinner_callback(_attr, _old_value, new_value):
        display_min_spinner.high = new_value - STEP
        image_glyph.color_mapper.high = new_value

    display_max_spinner = Spinner(
        title="Maximal Display Value:",
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    # ---- colormap display min value
    def display_min_spinner_callback(_attr, _old_value, new_value):
        display_max_spinner.low = new_value + STEP
        image_glyph.color_mapper.low = new_value

    display_min_spinner = Spinner(
        title="Minimal Display Value:",
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    def hkl_button_callback():
        # hkl needs loaded detector data; ignore clicks before a file is open.
        if not det_data:
            return

        index = index_spinner.value
        setup_type = "nb_bi" if radio_button_group.active else "nb"
        h, k, l = calculate_hkl(det_data, index, setup_type)
        image_source.data.update(h=[h], k=[k], l=[l])

    hkl_button = Button(label="Calculate hkl (slow)")
    hkl_button.on_click(hkl_button_callback)

    selection_list = TextAreaInput(rows=7)

    def selection_button_callback():
        # Record the current (integer-rounded outward) ranges of the overview
        # plots for the selected file.
        selection = [
            int(np.floor(det_x_range.start)),
            int(np.ceil(det_x_range.end)),
            int(np.floor(det_y_range.start)),
            int(np.ceil(det_y_range.end)),
            int(np.floor(frame_range.start)),
            int(np.ceil(frame_range.end)),
        ]

        # NOTE(review): assumes the file name ends with a 4-character id
        # followed by a 4-character extension (e.g. "...0123.hdf"); confirm.
        filename_id = filelist.value[-8:-4]
        roi_selection.setdefault(filename_id, []).append(selection)

        selection_list.value = str(roi_selection)

    selection_button = Button(label="Add selection")
    selection_button.on_click(selection_button_callback)

    # Final layout
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False),
        row(index_spinner))
    colormap_layout = column(colormap, auto_toggle, display_max_spinner,
                             display_min_spinner)
    hkl_layout = column(radio_button_group, hkl_button)

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
        ),
        frame_button_group,
    )

    tab_layout = row(
        column(
            upload_div,
            upload_button,
            filelist,
            layout_image,
            row(colormap_layout, hkl_layout),
        ),
        column(
            roi_avg_plot,
            layout_overview,
            row(selection_button, selection_list),
        ),
    )

    return Panel(child=tab_layout, title="Data Viewer")
示例#32
0
# Colour bar for the laser image, shown to the right of the figure.
color_bar = ColorBar(
    color_mapper=color_mapper,
    label_standoff=8,
    border_line_color=None,
    location=(0, 0),
)

img_obj = laser_fig.image(
    name='img',
    image='image',
    x='x',
    y='y',
    dw='dw',
    dh='dh',
    source=source,
    color_mapper=color_mapper,
)
laser_fig.xaxis.axis_label = 'x (mm)'
laser_fig.yaxis.axis_label = 'y (mm)'
laser_fig.add_layout(color_bar, 'right')

wavelength_txt_box = PVTextBox(
    f'{prefix}laser:wavelength',
    pvdb['input'][f'{prefix}laser:wavelength'],
)


def laser_wall_power(value):
    """Write the toggle state to the ``laser_on`` PV as an integer."""
    caput('laser_on', int(value))


# Power button
laser_on_button = Toggle(label="Laser Power", button_type="success")
laser_on_button.on_click(laser_wall_power)

# -------------------------------------------------------
# BEAM PVS AND SET UP
# One slider per beamline input PV, using the limits and unit from pvdb.
beamline_pv_sliders = []
for name in ['gun:voltage', 'sol1:current', 'sol2:current']:
    lolim, hilim, units = (
        pvdb['input'][f'{prefix}{name}'][field]
        for field in ('lolim', 'hilim', 'unit')
    )
    beamline_pv_sliders.append(
        PVSlider(f'{name} ({units})', f'{prefix}{name}', 1.0, lolim, hilim, 100)
    )

beamline_sliders = [pvs.slider for pvs in beamline_pv_sliders]

# Monitored output PVs whose name contains 'beam', keyed without the prefix.
beam_pvs = {
    pv_name.replace(prefix, ''): PV(f'{pv_name}', auto_monitor=True)
    for pv_name in pvdb['output']
    if 'beam' in pv_name
}
        else:
            exec(
                (s + '=\"' + strings[s][lang] + '\"').encode(encoding='utf-8'))


# Slider to change location of Forces F1 and F2
F1F2Location_slider = LatexSlider(
    value=20,
    start=1,
    end=39,
    step=1,
    value_unit="\\text{m}",
)
F1F2Location_slider.on_change('value', changeLength)

# Toggle button to show forces
show_button = Toggle(button_type="success")
show_button.on_click(changeShow)

# Button that switches the document language
lang_button = Button(button_type="success")
lang_button.on_click(changeLanguage)

# Description from HTML file
description_filename = join(dirname(__file__), "description.html")
description = LatexDiv(render_as_text=False, width=880)

# Set language (the widgets above must exist before this is called)
setDocumentLanguage(std_lang)

# Page layout: language button on top, then the description, then the plot
# next to the controls.
curdoc().add_root(
    column(
        row(Spacer(width=600), lang_button),
        description,
        row(plot, column(F1F2Location_slider, show_button, value_plot)),
    )
)
# curdoc().title = strings["curdoc().title"]["en"]