Пример #1
0
def modify_doc(doc):
    """Populate *doc* with an mp4 upload widget.

    Adds a heading and a ``FileInput`` to the document root.  When the
    widget's ``value`` changes, the payload is base64-decoded and written to
    ``fake_video.mp4`` in the current working directory, replacing any
    existing file of that name.

    :param doc: Bokeh document to add the widgets to.
    """
    h_input = Div(text="""<h2>Upload your mp4 file</h2> """, max_height=40)

    file_input = FileInput()

    def value_changed(attr, old, new):
        # ``new`` is the widget's new value; it is passed to b64decode, so it
        # is expected to be base64 text.
        data = b64decode(new)
        print(type(data))
        with open("fake_video.mp4", "wb") as binary_file:
            # Write the decoded bytes to disk.
            binary_file.write(data)

    file_input.on_change("value", value_changed)
    doc.add_root(column(h_input, file_input))
Пример #2
0
    def setup_bokeh_server(self, doc):
        """
        Build the playback page on the given Bokeh document.

        The Bokeh server serving this document has to be started on the main
        thread (see mainwindow.py).  This adds a single "Playback" tab that
        holds a file picker, registers a periodic callback that refreshes the
        live plot, and sets the page title.

        :param doc: Bokeh document the webpage is rendered from.
        :return: None
        """
        # NOTE: the "Dashboard" and "Profile" tabs built from the live-data
        # plots (self.create_bokeh_plots()) are currently disabled; only the
        # playback tab is exposed.
        playback_picker = FileInput(accept=".ens,.bin,.rtf")
        playback_picker.on_change("value", self.file_input_handler)

        playback_tab = Panel(child=layout([playback_picker]), title="Playback")
        doc.add_root(Tabs(tabs=[playback_tab]))

        # Periodic refresh of the live plot; Bokeh takes the period in ms.
        refresh_period_ms = 2500
        doc.add_periodic_callback(self.update_live_plot, refresh_period_ms)

        doc.title = "ADCP Dashboard"
Пример #3
0
        def modify_doc(doc):
            """Populate *doc* with an mp4 upload widget that spawns a worker.

            Adds a heading and a ``FileInput`` to the document root.  Whenever
            the widget's ``filename`` changes, ``self.spawn_new_worker`` is
            called with that filename and ``self.pool`` is printed.

            :param doc: Bokeh document to add the widgets to.
            """
            h_input = Div(text="""<h2>Upload your mp4 file</h2> """, max_height=40)

            file_input = FileInput()

            def value_changed(attr, old, new):
                # Only the selected file's name is forwarded; the uploaded
                # contents (file_input.value) are neither read nor saved here.
                # NOTE(review): confirm spawn_new_worker can locate the file
                # from this name alone.
                print("init new process")
                print(file_input.filename)
                self.spawn_new_worker(file_input.filename)
                print(self.pool)

            file_input.on_change("filename", value_changed)
            doc.add_root(column(h_input, file_input))
Пример #4
0
def modify_doc(doc):
    """Populate *doc* with an mp4 upload widget that runs a debug command.

    Adds a heading and a ``FileInput`` to the document root.  When the
    widget's ``value`` changes, the payload is base64-decoded,
    ``fake_video.mp4`` is (re)created empty, and a placeholder shell command
    is run; its exit code and captured output are printed.

    :param doc: Bokeh document to add the widgets to.
    """
    h_input = Div(text="""<h2>Upload your mp4 file</h2> """, max_height=40)

    file_input = FileInput()

    def value_changed(attr, old, new):
        data = b64decode(new)
        print(type(data))
        # NOTE(review): writing the upload is disabled, so opening in "wb"
        # only creates or truncates fake_video.mp4 -- confirm this is wanted.
        with open("fake_video.mp4", "wb"):
            pass

        print("trigger os popen here")

        # Fixed command string with no user input, so shell=True is safe here.
        cmd = "echo 'run command for stream debugger'"
        # run() drains both pipes while waiting (via communicate()), so a
        # child that writes a lot cannot block on a full pipe buffer and
        # deadlock us -- which Popen.wait() with PIPE can do.  It also closes
        # the pipes afterwards.
        completed = subprocess.run(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )

        exit_code = completed.returncode
        print(exit_code)
        if exit_code != 0:
            print("run failed")
            result = completed.stderr
        else:
            print("run successfully")
            for line in completed.stdout.splitlines(keepends=True):
                print(f"line: {line}")
            result = completed.stdout
        print(f"result: {result}")

    file_input.on_change("value", value_changed)
    doc.add_root(column(h_input, file_input))
Пример #5
0
# --- Figure appearance ------------------------------------------------------
# Background image; with Bokeh's default top-left anchor it covers
# x in [-62, 62] and y in [-35, 35], the same area the range bounds lock to.
img_path = "https://raw.githubusercontent.com/sanam407/Simulator/master/AstaZero.png"
plot.image_url(url=[img_path], x=-62, y=35, w=124, h=70)
plot.x_range.bounds = (-62, 62)
plot.y_range.bounds = (-35, 35)

# Dark theme: fills and outline.
plot.update(background_fill_color="#252e38",
            border_fill_color="#252e38",
            outline_line_color='#41454a')

# Hide grid lines and axes entirely.
plot.grid.grid_line_color = None
plot.axis.axis_label = None
plot.axis.visible = False

# Centered white title.
plot.title.update(align="center", text_color="white", text_font_size="15px")

# --- Document roots ---------------------------------------------------------
# The figure itself.
curdoc().add_root(plot)

# JSON upload widget (choose the JSON file exported from Vista); show_plot
# is invoked whenever its value changes.
file_input = FileInput(accept=".json", name="top")
file_input.on_change("value", show_plot)
curdoc().add_root(file_input)

# Left-hand column, initially showing a "No File Selected" placeholder.
left_panel = column(Div(text="No File Selected"), width=400, name="left")
curdoc().add_root(left_panel)

# Browser tab title.
curdoc().title = 'Vista Reachability Analysis'
Пример #6
0
    def __init__(self, m):
        """Create every widget, figure and the page layout of the viewer.

        Builds the subject/model selectors, slice/threshold/cluster sliders and
        toggles; the three slice figures (frontal a.k.a. coronal, axial,
        sagittal) with orientation labels, green crosshair spans and
        "Relevance>threshold per slice" guide figures; empty RGBA overlay
        images; region/cluster info divs; subject-metadata inputs (disabled
        until a subject is chosen); a relevance colour bar; the scan upload
        widget and residualization button; and finally ``self.layout``.

        Args:
            m: model object.  This method stores it as ``self.m`` and reads
                ``m.subj_bg.shape[0..2]`` (so presumably a 3-D volume --
                TODO confirm axis meaning).

        Relies on names defined elsewhere: module globals ``debug``,
        ``sorted_xs``, ``selected_model``, ``stored_models``, ``scale_factor``,
        ``flip_left_right_in_frontal_plot``, ``color_palette``,
        ``get_region_id``, ``get_region_name``; and ``self.render_backround()``,
        which presumably sets ``self.bg`` (read right after the call).
        """
        if debug: print("Initializing new View object...")
        # NOTE(review): `curdoc` is not called here, so this stores whatever the
        # name `curdoc` refers to (presumably bokeh's curdoc function), not a
        # Document -- confirm callers use self.curdoc().
        self.curdoc = curdoc
        self.m = m
        self.firstrun = True

        # --- Selection widgets and sliders ---------------------------------
        self.subject_select = Select(title="Subjects:",
                                     value=sorted_xs[0],
                                     options=sorted_xs,
                                     width=200)
        self.model_select = Select(title="Model:",
                                   value=selected_model,
                                   options=stored_models,
                                   width=200)
        # Each slice slider is 1-based and spans one axis of m.subj_bg:
        # frontal/coronal -> axis 2, axial -> axis 0, sagittal -> axis 1.
        self.slice_slider_frontal = Slider(start=1,
                                           end=m.subj_bg.shape[2],
                                           value=50,
                                           step=1,
                                           title="Coronal slice",
                                           width=200)
        self.slice_slider_axial = Slider(start=1,
                                         end=m.subj_bg.shape[0],
                                         value=50,
                                         step=1,
                                         title="Axial slice",
                                         width=200)
        self.slice_slider_sagittal = Slider(start=1,
                                            end=m.subj_bg.shape[1],
                                            value=50,
                                            step=1,
                                            title="Sagittal slice",
                                            width=200)
        self.threshold_slider = Slider(start=0,
                                       end=1,
                                       value=0.4,
                                       step=0.05,
                                       title="Relevance threshold",
                                       width=200)
        # 0..250 in steps of 10; matches the histogram bins set at the end.
        self.clustersize_slider = Slider(start=0,
                                         end=250,
                                         value=50,
                                         step=10,
                                         title="Minimum cluster size",
                                         width=200)
        self.transparency_slider = Slider(start=0,
                                          end=1,
                                          value=0.3,
                                          step=0.05,
                                          title="Overlay transparency",
                                          width=200)

        # Initialize the figures:
        # Small, tool-less "guide" figures (one per view) plus the cluster-size
        # histogram; axes hidden and no range padding.
        self.guide_frontal = figure(plot_width=208,
                                    plot_height=70,
                                    title='Relevance>threshold per slice:',
                                    toolbar_location=None,
                                    active_drag=None,
                                    active_inspect=None,
                                    active_scroll=None,
                                    active_tap=None)
        self.guide_frontal.title.text_font = 'arial'
        self.guide_frontal.title.text_font_style = 'normal'
        # guide_frontal.title.text_font_size = '10pt'
        self.guide_frontal.axis.visible = False
        self.guide_frontal.x_range.range_padding = 0
        self.guide_frontal.y_range.range_padding = 0

        self.guide_axial = figure(plot_width=208,
                                  plot_height=70,
                                  title='Relevance>threshold per slice:',
                                  toolbar_location=None,
                                  active_drag=None,
                                  active_inspect=None,
                                  active_scroll=None,
                                  active_tap=None)
        self.guide_axial.title.text_font = 'arial'
        self.guide_axial.title.text_font_style = 'normal'
        # guide_axial.title.text_font_size = '10pt'
        self.guide_axial.axis.visible = False
        self.guide_axial.x_range.range_padding = 0
        self.guide_axial.y_range.range_padding = 0

        self.guide_sagittal = figure(plot_width=208,
                                     plot_height=70,
                                     title='Relevance>threshold per slice:',
                                     toolbar_location=None,
                                     active_drag=None,
                                     active_inspect=None,
                                     active_scroll=None,
                                     active_tap=None)
        self.guide_sagittal.title.text_font = 'arial'
        self.guide_sagittal.title.text_font_style = 'normal'
        # guide_sagittal.title.text_font_size = '10pt'
        self.guide_sagittal.axis.visible = False
        self.guide_sagittal.x_range.range_padding = 0
        self.guide_sagittal.y_range.range_padding = 0

        self.clusthist = figure(plot_width=208,
                                plot_height=70,
                                title='Distribution of cluster sizes:',
                                toolbar_location=None,
                                active_drag=None,
                                active_inspect=None,
                                active_scroll=None,
                                active_tap=None)
        self.clusthist.title.text_font = 'arial'
        self.clusthist.title.text_font_style = 'normal'
        self.clusthist.axis.visible = False
        self.clusthist.x_range.range_padding = 0
        self.clusthist.y_range.range_padding = 0

        # --- Frontal (coronal) view: width from axis 1, height from axis 0 --
        self.p_frontal = figure(
            plot_width=int(np.floor(m.subj_bg.shape[1] * scale_factor)),
            plot_height=int(np.floor(m.subj_bg.shape[0] * scale_factor)),
            title='',
            toolbar_location=None,
            active_drag=None,
            active_inspect=None,
            active_scroll=None,
            active_tap=None)
        self.p_frontal.axis.visible = False
        self.p_frontal.x_range.range_padding = 0
        self.p_frontal.y_range.range_padding = 0

        self.flip_frontal_view = Toggle(label='Flip L/R orientation',
                                        button_type='default',
                                        width=200,
                                        active=flip_left_right_in_frontal_plot)

        # "L"/"R" corner labels near the top of the frontal view; their text is
        # swapped when the initial flip setting is active.
        self.orientation_label_shown_left = Label(
            text='R' if flip_left_right_in_frontal_plot else 'L',
            render_mode='css',
            x=3,
            y=self.m.subj_bg.shape[0] - 13,
            text_align='left',
            text_color='white',
            text_font_size='20px',
            border_line_color='white',
            border_line_alpha=0,
            background_fill_color='black',
            background_fill_alpha=0,
            level='overlay',
            visible=True)
        self.orientation_label_shown_right = Label(
            text='L' if flip_left_right_in_frontal_plot else 'R',
            render_mode='css',
            x=self.m.subj_bg.shape[1] - 3,
            y=self.m.subj_bg.shape[0] - 13,
            text_align='right',
            text_color='white',
            text_font_size='20px',
            border_line_color='white',
            border_line_alpha=0,
            background_fill_color='black',
            background_fill_alpha=0,
            level='overlay',
            visible=True)

        self.p_frontal.add_layout(self.orientation_label_shown_left, 'center')
        self.p_frontal.add_layout(self.orientation_label_shown_right, 'center')

        # The vertical crosshair line on the frontal view that indicates the selected sagittal slice.
        self.frontal_crosshair_from_sagittal = Span(
            location=self.slice_slider_sagittal.value - 1,
            dimension='height',
            line_color='green',
            line_width=1,
            render_mode="css")

        # The horizontal crosshair line on the frontal view that indicates the selected axial slice.
        self.frontal_crosshair_from_axial = Span(
            location=self.slice_slider_axial.value - 1,
            dimension='width',
            line_color='green',
            line_width=1,
            render_mode="css")
        self.p_frontal.add_layout(self.frontal_crosshair_from_sagittal)
        self.p_frontal.add_layout(self.frontal_crosshair_from_axial)

        # --- Axial view: width from axis 1, height from axis 2 --------------
        self.p_axial = figure(
            plot_width=int(np.floor(m.subj_bg.shape[1] * scale_factor)),
            plot_height=int(np.floor(m.subj_bg.shape[2] * scale_factor)),
            title='',
            toolbar_location=None,
            active_drag=None,
            active_inspect=None,
            active_scroll=None,
            active_tap=None)
        self.p_axial.axis.visible = False
        self.p_axial.x_range.range_padding = 0
        self.p_axial.y_range.range_padding = 0

        # Vertical crosshair on the axial view for the selected sagittal slice.
        self.axial_crosshair_from_sagittal = Span(
            location=self.slice_slider_sagittal.value - 1,
            dimension='height',
            line_color='green',
            line_width=1,
            render_mode="css")
        # Horizontal crosshair on the axial view for the selected frontal slice,
        # mirrored against the slider's end value.
        self.axial_crosshair_from_frontal = Span(
            location=self.slice_slider_frontal.end -
            self.slice_slider_frontal.value + 1,
            dimension='width',
            line_color='green',
            line_width=1,
            render_mode="css")
        self.p_axial.add_layout(self.axial_crosshair_from_sagittal)
        self.p_axial.add_layout(self.axial_crosshair_from_frontal)

        # --- Sagittal view: width from axis 2, height from axis 0 -----------
        self.p_sagittal = figure(
            plot_width=int(np.floor(m.subj_bg.shape[2] * scale_factor)),
            plot_height=int(np.floor(m.subj_bg.shape[0] * scale_factor)),
            title='',
            toolbar_location=None,
            active_drag=None,
            active_inspect=None,
            active_scroll=None,
            active_tap=None)
        self.p_sagittal.axis.visible = False
        self.p_sagittal.x_range.range_padding = 0
        self.p_sagittal.y_range.range_padding = 0

        # Vertical crosshair on the sagittal view for the selected frontal slice.
        self.sagittal_crosshair_from_frontal = Span(
            location=self.slice_slider_frontal.value - 1,
            dimension='height',
            line_color='green',
            line_width=1,
            render_mode="css")
        # Horizontal crosshair on the sagittal view for the selected axial slice.
        # NOTE(review): the frontal view places its axial crosshair at
        # `value - 1`, yet both the frontal and sagittal images are np.flipud'ed
        # along axis 0 below; this one uses `end - value - 1`. One of the two is
        # likely inconsistent -- confirm against the slider update callbacks.
        self.sagittal_crosshair_from_axial = Span(
            location=self.slice_slider_axial.end -
            self.slice_slider_axial.value - 1,
            dimension='width',
            line_color='green',
            line_width=1,
            render_mode="css")
        self.p_sagittal.add_layout(self.sagittal_crosshair_from_frontal)
        self.p_sagittal.add_layout(self.sagittal_crosshair_from_axial)

        # "Processing scan..." overlay centred on the axial view; hidden until
        # it is made visible elsewhere.
        self.loading_label = Label(text='Processing scan...',
                                   render_mode='css',
                                   x=self.m.subj_bg.shape[1] // 2,
                                   y=self.m.subj_bg.shape[2] // 2,
                                   text_align='center',
                                   text_color='white',
                                   text_font_size='25px',
                                   text_font_style='italic',
                                   border_line_color='white',
                                   border_line_alpha=1.0,
                                   background_fill_color='black',
                                   background_fill_alpha=0.5,
                                   level='overlay',
                                   visible=False)

        self.p_axial.add_layout(self.loading_label)

        # Must run before the lines below, which read self.bg.
        self.render_backround()

        # create empty plot objects with empty ("fully transparent") ColumnDataSources):
        # Each *_zeros array has the shape of the slice shown in that view
        # (frontal: flipud of a z-slice; axial: rot90 of an x-slice; sagittal:
        # flipud of a y-slice, index chosen according to the flip toggle).
        self.frontal_zeros = np.zeros_like(
            np.flipud(self.bg[:, :, self.slice_slider_frontal.value - 1]))
        self.axial_zeros = np.zeros_like(
            np.rot90(self.bg[self.m.subj_bg.shape[0] -
                             self.slice_slider_axial.value, :, :]))
        self.sagittal_zeros = np.zeros_like(
            np.flipud(
                self.
                bg[:, self.slice_slider_sagittal.end -
                   self.slice_slider_sagittal.value if self.flip_frontal_view.
                   active else self.slice_slider_sagittal.value - 1, :]))
        # Indexing with the scalar True selects every element, so each line
        # fills the whole array with 255.
        self.frontal_zeros[True] = 255  # value for a fully transparent pixel
        self.axial_zeros[True] = 255
        self.sagittal_zeros[True] = 255

        # Three identical placeholder images per view, all anchored at (0, 0);
        # presumably one per overlay layer that is filled in later -- TODO confirm.
        self.frontal_data = ColumnDataSource(data=dict(
            image=[self.frontal_zeros, self.frontal_zeros, self.frontal_zeros],
            x=[0, 0, 0],
            y=[0, 0, 0]))
        self.axial_data = ColumnDataSource(data=dict(
            image=[self.axial_zeros, self.axial_zeros, self.axial_zeros],
            x=[0, 0, 0],
            y=[0, 0, 0]))
        self.sagittal_data = ColumnDataSource(data=dict(image=[
            self.sagittal_zeros, self.sagittal_zeros, self.sagittal_zeros
        ],
                                                        x=[0, 0, 0],
                                                        y=[0, 0, 0]))

        # RGBA image glyphs sized in data units to the slice dimensions.
        self.p_frontal.image_rgba(image="image",
                                  x="x",
                                  y="y",
                                  dw=self.frontal_zeros.shape[1],
                                  dh=self.frontal_zeros.shape[0],
                                  source=self.frontal_data)
        self.p_axial.image_rgba(image="image",
                                x="x",
                                y="y",
                                dw=self.axial_zeros.shape[1],
                                dh=self.axial_zeros.shape[0],
                                source=self.axial_data)
        self.p_sagittal.image_rgba(image="image",
                                   x="x",
                                   y="y",
                                   dw=self.sagittal_zeros.shape[1],
                                   dh=self.sagittal_zeros.shape[0],
                                   source=self.sagittal_data)
        self.toggle_transparency = Toggle(label='Hide relevance overlay',
                                          button_type='default',
                                          width=200)
        self.toggle_regions = Toggle(label='Show outline of atlas region',
                                     button_type='default',
                                     width=200)

        # Atlas region at the current crosshair position. Arguments are
        # (axial index, sagittal index honouring the flip toggle, frontal index).
        self.region_ID = get_region_id(
            self.slice_slider_axial.value - 1, self.slice_slider_sagittal.end -
            self.slice_slider_sagittal.value if self.flip_frontal_view.active
            else self.slice_slider_sagittal.value - 1,
            self.slice_slider_frontal.value - 1)
        self.selected_region = get_region_name(
            self.slice_slider_axial.value - 1, self.slice_slider_sagittal.end -
            self.slice_slider_sagittal.value if self.flip_frontal_view.active
            else self.slice_slider_sagittal.value - 1,
            self.slice_slider_frontal.value - 1)
        self.region_div = Div(text="Region: " + self.selected_region,
                              sizing_mode="stretch_both",
                              css_classes=["region_divs"])

        self.cluster_size_div = Div(
            text="Cluster Size: " + "0", css_classes=[
                "cluster_divs"
            ])  # initialize with 0 because clust_labelimg does not exist yet
        self.cluster_mean_div = Div(text="Mean Intensity: " + "0",
                                    css_classes=["cluster_divs"])
        self.cluster_peak_div = Div(text="Peak Intensity: " + "0",
                                    css_classes=["cluster_divs"])

        # see InteractiveVis/static/ for default formatting/style definitions
        # Subject metadata inputs, each half the frontal-view width minus 10 px;
        # all start disabled.
        self.age_spinner = Spinner(
            title="Age:",
            placeholder="years",
            mode="int",
            low=55,
            high=99,
            width=int(np.floor(m.subj_bg.shape[1] * scale_factor) // 2 - 10),
            disabled=True)  # no subject selected at time of initialization
        self.sex_select = Select(
            title="Sex:",
            value="N/A",
            options=["male", "female", "N/A"],
            width=int(np.floor(m.subj_bg.shape[1] * scale_factor) // 2 - 10),
            disabled=True)
        self.tiv_spinner = Spinner(
            title="TIV:",
            placeholder="cm³",
            mode="float",
            low=1000,
            high=2100,
            width=int(np.floor(m.subj_bg.shape[1] * scale_factor) // 2 - 10),
            disabled=True)
        self.field_strength_select = Select(
            title="Field Strength [T]:",
            value="1.5",
            options=["1.5", "3.0"],
            width=int(np.floor(m.subj_bg.shape[1] * scale_factor) // 2 - 10),
            disabled=True)

        # Empty dummy figure to add ColorBar to, because annotations (like a ColorBar) must have a
        # parent figure in Bokeh:
        self.p_color_bar = figure(
            plot_width=100,
            plot_height=int(np.floor(m.subj_bg.shape[0] * scale_factor)),
            title='',
            toolbar_location=None,
            active_drag=None,
            active_inspect=None,
            active_scroll=None,
            active_tap=None,
            outline_line_alpha=0.0)
        self.p_color_bar.axis.visible = False
        self.p_color_bar.x_range.range_padding = 0
        self.p_color_bar.y_range.range_padding = 0

        # Relevance colour scale spans [-1, 1].
        self.color_mapper = LinearColorMapper(palette=color_palette,
                                              low=-1,
                                              high=1)
        self.color_bar = ColorBar(color_mapper=self.color_mapper,
                                  title="Relevance")
        self.p_color_bar.add_layout(self.color_bar)
        self.scan_upload = FileInput(accept='.nii.gz, .nii')
        self.residualize_button = Button(
            label="Start residualization and view scan", disabled=True)

        # No-op placeholder handler; see the TODO below.
        def dummy():
            pass

        self.residualize_button.on_click(
            dummy
        )  # TODO: remove this once on_click is working when setting callback only from the model class (bug in Bokeh 2.2.x ?)

        # Initialize column layout:
        # Left column: controls and cluster info. Right column: subject
        # metadata row, residualize button, then the three views (each with
        # slider and guide) plus the colour bar.
        self.layout = row(
            column(
                self.subject_select, self.model_select,
                Spacer(height=40, width=200, sizing_mode='scale_width'),
                self.threshold_slider, self.clusthist, self.clustersize_slider,
                self.transparency_slider, self.toggle_transparency,
                self.toggle_regions, self.region_div,
                column(self.cluster_size_div, self.cluster_mean_div,
                       self.cluster_peak_div)),
            column(
                row(self.age_spinner,
                    self.sex_select,
                    self.tiv_spinner,
                    self.field_strength_select,
                    self.scan_upload,
                    css_classes=["subject_divs"]),
                row(self.residualize_button),
                row(
                    column(self.p_frontal, self.slice_slider_frontal,
                           self.guide_frontal, self.flip_frontal_view),
                    column(self.p_axial, self.slice_slider_axial,
                           self.guide_axial),
                    column(self.p_sagittal, self.slice_slider_sagittal,
                           self.guide_sagittal), column(self.p_color_bar))))

        self.clust_hist_bins = list(range(
            0, 250 + 1,
            10))  # list from (0, 10, .., 250); range max is slider_max_size+1
def create(palm):
    doc = curdoc()

    # Calibration averaged waveforms per photon energy
    waveform_plot = Plot(
        title=Title(text="eTOF calibration waveforms"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=760,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    waveform_plot.toolbar.logo = None
    waveform_plot_hovertool = HoverTool(
        tooltips=[("energy, eV", "@en"), ("eTOF bin", "$x{0.}")])

    waveform_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(),
                            ResetTool(), waveform_plot_hovertool)

    # ---- axes
    waveform_plot.add_layout(LinearAxis(axis_label="eTOF time bin"),
                             place="below")
    waveform_plot.add_layout(LinearAxis(axis_label="Intensity",
                                        major_label_orientation="vertical"),
                             place="left")

    # ---- grid lines
    waveform_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    waveform_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- multiline glyphs
    waveform_ref_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_ref_multiline = waveform_plot.add_glyph(
        waveform_ref_source, MultiLine(xs="xs", ys="ys", line_color="blue"))

    waveform_str_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_str_multiline = waveform_plot.add_glyph(
        waveform_str_source, MultiLine(xs="xs", ys="ys", line_color="red"))

    # ---- legend
    waveform_plot.add_layout(
        Legend(items=[(
            "reference",
            [waveform_ref_multiline]), ("streaked",
                                        [waveform_str_multiline])]))
    waveform_plot.legend.click_policy = "hide"

    # ---- vertical spans
    photon_peak_ref_span = Span(location=0,
                                dimension="height",
                                line_dash="dashed",
                                line_color="blue")
    photon_peak_str_span = Span(location=0,
                                dimension="height",
                                line_dash="dashed",
                                line_color="red")
    waveform_plot.add_layout(photon_peak_ref_span)
    waveform_plot.add_layout(photon_peak_str_span)

    # Calibration fit plot
    fit_plot = Plot(
        title=Title(text="eTOF calibration fit"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    fit_plot.toolbar.logo = None
    fit_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

    # ---- axes
    fit_plot.add_layout(LinearAxis(axis_label="Photoelectron peak shift"),
                        place="below")
    fit_plot.add_layout(LinearAxis(axis_label="Photon energy, eV",
                                   major_label_orientation="vertical"),
                        place="left")

    # ---- grid lines
    fit_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    fit_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- circle glyphs
    fit_ref_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_circle = fit_plot.add_glyph(
        fit_ref_circle_source, Circle(x="x", y="y", line_color="blue"))
    fit_str_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_circle = fit_plot.add_glyph(fit_str_circle_source,
                                        Circle(x="x", y="y", line_color="red"))

    # ---- line glyphs
    fit_ref_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_line = fit_plot.add_glyph(fit_ref_line_source,
                                      Line(x="x", y="y", line_color="blue"))
    fit_str_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_line = fit_plot.add_glyph(fit_str_line_source,
                                      Line(x="x", y="y", line_color="red"))

    # ---- legend
    fit_plot.add_layout(
        Legend(items=[
            ("reference", [fit_ref_circle, fit_ref_line]),
            ("streaked", [fit_str_circle, fit_str_line]),
        ]))
    fit_plot.legend.click_policy = "hide"

    # Calibration results datatables
    def datatable_ref_source_callback(_attr, _old_value, new_value):
        for en, ps, use in zip(new_value["energy"], new_value["peak_pos_ref"],
                               new_value["use_in_fit"]):
            palm.etofs["0"].calib_data.loc[
                en, "calib_tpeak"] = ps if ps != "NaN" else np.nan
            palm.etofs["0"].calib_data.loc[en, "use_in_fit"] = use

        calib_res = {}
        for etof_key in palm.etofs:
            calib_res[etof_key] = palm.etofs[etof_key].fit_calibration_curve()
        update_calibration_plot(calib_res)

    datatable_ref_source = ColumnDataSource(
        dict(energy=["", "", ""],
             peak_pos_ref=["", "", ""],
             use_in_fit=[True, True, True]))
    datatable_ref_source.on_change("data", datatable_ref_source_callback)

    datatable_ref = DataTable(
        source=datatable_ref_source,
        columns=[
            TableColumn(field="energy",
                        title="Photon Energy, eV",
                        editor=IntEditor()),
            TableColumn(field="peak_pos_ref",
                        title="Reference Peak",
                        editor=IntEditor()),
            TableColumn(field="use_in_fit",
                        title=" ",
                        editor=CheckboxEditor(),
                        width=80),
        ],
        index_position=None,
        editable=True,
        height=300,
        width=250,
    )

    def datatable_str_source_callback(_attr, _old_value, new_value):
        for en, ps, use in zip(new_value["energy"], new_value["peak_pos_str"],
                               new_value["use_in_fit"]):
            palm.etofs["1"].calib_data.loc[
                en, "calib_tpeak"] = ps if ps != "NaN" else np.nan
            palm.etofs["1"].calib_data.loc[en, "use_in_fit"] = use

        calib_res = {}
        for etof_key in palm.etofs:
            calib_res[etof_key] = palm.etofs[etof_key].fit_calibration_curve()
        update_calibration_plot(calib_res)

    datatable_str_source = ColumnDataSource(
        dict(energy=["", "", ""],
             peak_pos_str=["", "", ""],
             use_in_fit=[True, True, True]))
    datatable_str_source.on_change("data", datatable_str_source_callback)

    datatable_str = DataTable(
        source=datatable_str_source,
        columns=[
            TableColumn(field="energy",
                        title="Photon Energy, eV",
                        editor=IntEditor()),
            TableColumn(field="peak_pos_str",
                        title="Streaked Peak",
                        editor=IntEditor()),
            TableColumn(field="use_in_fit",
                        title=" ",
                        editor=CheckboxEditor(),
                        width=80),
        ],
        index_position=None,
        editable=True,
        height=350,
        width=250,
    )

    # eTOF calibration folder path text input
    def path_textinput_callback(_attr, _old_value, _new_value):
        """Rescan the newly entered folder for ECO scans and calibration files."""
        path_periodic_update()  # refresh the ECO scans dropdown
        update_load_dropdown_menu()  # refresh the "Load" calibration dropdown

    path_textinput = TextInput(title="eTOF calibration path:",
                               value=os.path.join(os.path.expanduser("~")),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # eTOF calibration eco scans dropdown
    def scans_dropdown_callback(event):
        """Show the picked ECO scan file name as the dropdown's label."""
        picked_scan = event.item
        scans_dropdown.label = picked_scan

    scans_dropdown = Dropdown(label="ECO scans",
                              button_type="default",
                              menu=[])
    scans_dropdown.on_click(scans_dropdown_callback)

    # ---- etof scans periodic update
    def path_periodic_update():
        """List the ``.json`` files of the calibration folder in the scans dropdown.

        Entries are sorted in descending name order. A non-existing folder
        yields an empty menu.
        """
        folder = path_textinput.value
        json_files = []
        if os.path.isdir(folder):
            with os.scandir(folder) as entries:
                json_files = [(entry.name, entry.name) for entry in entries
                              if entry.is_file() and entry.name.endswith(".json")]
        scans_dropdown.menu = sorted(json_files, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    path_tab = Panel(child=column(
        path_textinput,
        scans_dropdown,
    ),
                     title="Path")

    upload_div = Div(text="Upload ECO scan (top) and all hdf5 files (bottom):")

    # ECO scan upload FileInput
    def eco_fileinput_callback(_attr, _old, new):
        """Decode an uploaded ECO scan (base64-encoded JSON) and print it."""
        raw_bytes = base64.b64decode(new)
        with io.BytesIO(raw_bytes) as eco_scan:
            print(json.load(eco_scan))

    eco_fileinput = FileInput(accept=".json", disabled=True)
    eco_fileinput.on_change("value", eco_fileinput_callback)

    # HDF5 upload FileInput
    def hdf5_fileinput_callback(_attr, _old, new):
        """Open each uploaded HDF5 file (base64-encoded) and print its top-level keys."""
        for encoded in new:
            buffer = io.BytesIO(base64.b64decode(encoded))
            with buffer, h5py.File(buffer, "r") as h5f:
                print(h5f.keys())

    hdf5_fileinput = FileInput(accept=".hdf5,.h5",
                               multiple=True,
                               disabled=True)
    hdf5_fileinput.on_change("value", hdf5_fileinput_callback)

    upload_tab = Panel(child=column(upload_div, eco_fileinput, hdf5_fileinput),
                       title="Upload")

    # Calibrate button
    def calibrate_button_callback():
        """Calibrate both eTOFs and load the results into the two data tables.

        Calibration from the ECO scan selected in ``scans_dropdown`` is tried
        first. If it raises (presumably e.g. when no scan is selected and the
        label is still the placeholder "ECO scans"), the folder-based
        calibration is used instead. The failure reason is printed rather than
        silently discarded.

        Updating the table sources fires their ``on_change("data")`` callbacks,
        which refit and redraw the calibration plot.
        """
        try:
            palm.calibrate_etof_eco(eco_scan_filename=os.path.join(
                path_textinput.value, scans_dropdown.label))
        except Exception as exc:
            # Deliberate best-effort fallback, but make the reason visible so
            # a broken ECO scan is not mistaken for a successful calibration.
            print(f"ECO scan calibration failed ({exc!r}); "
                  "falling back to folder calibration.")
            palm.calibrate_etof(folder_name=path_textinput.value)

        datatable_ref_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_ref=palm.etofs["0"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["0"].calib_data["use_in_fit"].tolist(),
        )

        # the energy column of the streaked table mirrors the reference eTOF
        datatable_str_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_str=palm.etofs["1"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["1"].calib_data["use_in_fit"].tolist(),
        )

    def update_calibration_plot(calib_res):
        """Redraw waveforms, photon-peak markers, fit curves and constants.

        :param calib_res: mapping with keys "0" (reference eTOF) and "1"
            (streaked eTOF); each value unpacks as ``((a, b), x, y)``, i.e. the
            fit constants followed by the fitted data points (as returned by
            ``fit_calibration_curve`` — TODO confirm element meaning there).
        """
        etof_ref = palm.etofs["0"]
        etof_str = palm.etofs["1"]

        # Stack waveforms vertically: every next pair is shifted down by the
        # larger peak value of the two waveforms, so that traces don't overlap.
        shift_val = 0
        etof_ref_wf_shifted = []
        etof_str_wf_shifted = []
        for wf_ref, wf_str in zip(etof_ref.calib_data["waveform"],
                                  etof_str.calib_data["waveform"]):
            shift_val -= max(wf_ref.max(), wf_str.max())
            etof_ref_wf_shifted.append(wf_ref + shift_val)
            etof_str_wf_shifted.append(wf_str + shift_val)

        # Each waveform is drawn against the time-bin index 0..internal_time_bins-1
        waveform_ref_source.data.update(
            xs=len(etof_ref.calib_data) *
            [list(range(etof_ref.internal_time_bins))],
            ys=etof_ref_wf_shifted,
            en=etof_ref.calib_data.index.tolist(),
        )

        waveform_str_source.data.update(
            xs=len(etof_str.calib_data) *
            [list(range(etof_str.internal_time_bins))],
            ys=etof_str_wf_shifted,
            en=etof_str.calib_data.index.tolist(),
        )

        # Mark the calibrated photon-peak time of each eTOF
        photon_peak_ref_span.location = etof_ref.calib_t0
        photon_peak_str_span.location = etof_str.calib_t0

        def plot_fit(time, calib_a, calib_b):
            # Sample the model en = (a / t)**2 + b on 100 points spanning the
            # (NaN-ignoring) range of the measured times.
            time_fit = np.linspace(np.nanmin(time), np.nanmax(time), 100)
            en_fit = (calib_a / time_fit)**2 + calib_b
            return time_fit, en_fit

        def update_plot(calib_results, circle, line):
            # Measured points go to the circle source, the fit curve to the line source
            (a, c), x, y = calib_results
            x_fit, y_fit = plot_fit(x, a, c)
            circle.data.update(x=x, y=y)
            line.data.update(x=x_fit, y=y_fit)

        update_plot(calib_res["0"], fit_ref_circle_source, fit_ref_line_source)
        update_plot(calib_res["1"], fit_str_circle_source, fit_str_line_source)

        # Show the calibration constants stored on the eTOF objects
        calib_const_div.text = f"""
        a_str = {etof_str.calib_a:.2f}<br>
        b_str = {etof_str.calib_b:.2f}<br>
        <br>
        a_ref = {etof_ref.calib_a:.2f}<br>
        b_ref = {etof_ref.calib_b:.2f}
        """

    calibrate_button = Button(label="Calibrate eTOF",
                              button_type="default",
                              width=250)
    calibrate_button.on_click(calibrate_button_callback)

    # Photon peak noise threshold value text input
    def phot_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        """Apply a new photon-peak noise threshold to every eTOF.

        Only strictly positive values are accepted. Anything else is rejected by
        restoring the previous spinner value. That includes ``None``, which a
        cleared spinner may report depending on the Bokeh version. Previously
        ``None > 0`` raised a TypeError.
        """
        if new_value is not None and new_value > 0:
            for etof in palm.etofs.values():
                etof.photon_peak_noise_thr = new_value
        else:
            phot_peak_noise_thr_spinner.value = old_value

    phot_peak_noise_thr_spinner = Spinner(title="Photon peak noise threshold:",
                                          value=1,
                                          step=0.1)
    phot_peak_noise_thr_spinner.on_change(
        "value", phot_peak_noise_thr_spinner_callback)

    # Electron peak noise threshold value text input
    def el_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        """Apply a new electron-peak noise threshold to every eTOF.

        Only strictly positive values are accepted. Anything else is rejected by
        restoring the previous spinner value. That includes ``None``, which a
        cleared spinner may report depending on the Bokeh version. Previously
        ``None > 0`` raised a TypeError.
        """
        if new_value is not None and new_value > 0:
            for etof in palm.etofs.values():
                etof.electron_peak_noise_thr = new_value
        else:
            el_peak_noise_thr_spinner.value = old_value

    el_peak_noise_thr_spinner = Spinner(title="Electron peak noise threshold:",
                                        value=10,
                                        step=0.1)
    el_peak_noise_thr_spinner.on_change("value",
                                        el_peak_noise_thr_spinner_callback)

    # Save calibration button
    def save_button_callback():
        """Save the current eTOF calibration into the chosen folder and relist it."""
        target_folder = path_textinput.value
        palm.save_etof_calib(path=target_folder)
        update_load_dropdown_menu()  # make the new file show up under "Load"

    save_button = Button(label="Save", button_type="default", width=250)
    save_button.on_click(save_button_callback)

    # Load calibration button
    def load_dropdown_callback(event):
        """Load the selected calibration file and show it in both data tables.

        Empty selections are ignored.
        """
        calib_name = event.item
        if not calib_name:
            return

        palm.load_etof_calib(os.path.join(path_textinput.value, calib_name))

        # reference eTOF ("0") table
        datatable_ref_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_ref=palm.etofs["0"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["0"].calib_data["use_in_fit"].tolist(),
        )

        # streaked eTOF ("1") table; its energy column mirrors the reference eTOF
        datatable_str_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_str=palm.etofs["1"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["1"].calib_data["use_in_fit"].tolist(),
        )

    def update_load_dropdown_menu():
        """List the ``.palm_etof`` calibration files of the folder in the Load dropdown.

        Each menu item shows the file name without its extension. A missing
        folder empties the menu and paints the button as "danger".
        """
        calib_file_ext = ".palm_etof"
        folder = path_textinput.value

        if not os.path.isdir(folder):
            load_dropdown.button_type = "danger"
            load_dropdown.menu = []
            return

        with os.scandir(folder) as entries:
            calib_files = [(entry.name[:-len(calib_file_ext)], entry.name)
                           for entry in entries
                           if entry.is_file() and entry.name.endswith(calib_file_ext)]
        load_dropdown.button_type = "default"
        load_dropdown.menu = sorted(calib_files, reverse=True)

    doc.add_next_tick_callback(update_load_dropdown_menu)
    doc.add_periodic_callback(update_load_dropdown_menu, 5000)

    load_dropdown = Dropdown(label="Load", menu=[], width=250)
    load_dropdown.on_click(load_dropdown_callback)

    # eTOF fitting equation
    fit_eq_div = Div(
        text="""Fitting equation:<br><br><img src="/palm/static/5euwuy.gif">"""
    )

    # Calibration constants
    calib_const_div = Div(text=f"""
        a_str = {0}<br>
        b_str = {0}<br>
        <br>
        a_ref = {0}<br>
        b_ref = {0}
        """)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, fit_plot),
            Spacer(width=30),
            column(
                Tabs(tabs=[path_tab, upload_tab]),
                calibrate_button,
                phot_peak_noise_thr_spinner,
                el_peak_noise_thr_spinner,
                row(save_button, load_dropdown),
                row(datatable_ref, datatable_str),
                calib_const_div,
                fit_eq_div,
            ),
        ))

    return Panel(child=tab_layout, title="eTOF Calibration")
Пример #8
0
def slider_handler(attr, old, new):
    """Show only the image whose row index in ``source`` equals the slider value."""
    only_selected = IndexFilter([new])
    img_plot.view = CDSView(source=source, filters=[only_selected])


# make the handler itself
def file_handler(attr, old, new):
    """Add the uploaded image to ``source`` while keeping the displayed selection."""
    add_image(new, source, filepicker.filename)
    # rebuild the view over the enlarged source, reusing the current indices
    shown_indices = img_plot.view.filters[0].indices
    img_plot.view = CDSView(source=source,
                            filters=[IndexFilter(shown_indices)])


# Make the file picker: single image upload, handled by file_handler
filepicker = FileInput(accept="image/*", multiple=False)
filepicker.on_change('value', file_handler)

# Make the slider: its value is the row index shown by slider_handler
# NOTE(review): end is fixed at 1 here; confirm whether something else extends
# it as more images are uploaded.
slider = Slider(start=0, end=1, value=0, step=1, title="Image")
slider.on_change('value', slider_handler)
slider_handler_js = CustomJS(args=dict(source=source,
                                       view=img_plot.view,
                                       filter=img_plot.view.filters[0]),
                             code="""
        console.log('range_slider: value=' + this.value, this.toString())
        filter.indices = [this.value]
        //console.log('filter_ind: value=' + filter.indices[0])
        source.change.emit()
        view.change.emit()
        filter.change.emit()
Пример #9
0
def get_file_widget():
    """Create and return a FileInput widget without any file-type restriction."""
    return FileInput()
def create():
    doc = curdoc()
    det_data = {}
    cami_meta = {}

    def proposal_textinput_callback(_attr, _old, new):
        nonlocal cami_meta
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith(".hdf"):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list

        cami_meta = {}

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal cami_meta
        with io.StringIO(base64.b64decode(new).decode()) as file:
            cami_meta = pyzebra.parse_h5meta(file)
            file_list = cami_meta["filelist"]
            file_select.options = [(entry, os.path.basename(entry))
                                   for entry in file_list]

    upload_div = Div(text="or upload .cami file:", margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".cami", width=200)
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        """Display detector frame ``index`` and refresh everything derived from it.

        Updates the side projections, the main image, the auto display range
        (if enabled), the metadata table and the hover-only gamma/nu/omega
        layers. Requires a loaded file (``det_data["data"]``).

        :param index: frame index; defaults to the index spinner's value.
        """
        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        # Projections: column means along x and row means along y, plotted at
        # pixel centers (+0.5).
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        # Clear hkl layers: they are only computed on demand (hkl button) for a
        # specific frame and would be stale for the new one.
        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if main_auto_checkbox.active:
            # Auto range: fit the color mapper and spinners to this frame's min/max
            im_min = np.min(current_image)
            im_max = np.max(current_image)

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

        # Optional per-frame metadata; shown as None when absent from the file
        if "mf" in det_data:
            metadata_table_source.data.update(mf=[det_data["mf"][index]])
        else:
            metadata_table_source.data.update(mf=[None])

        if "temp" in det_data:
            metadata_table_source.data.update(temp=[det_data["temp"][index]])
        else:
            metadata_table_source.data.update(temp=[None])

        # Per-pixel angles for the hover tooltip; omega is constant per frame,
        # so it is broadcast into a full (IMAGE_H, IMAGE_W) array.
        gamma, nu = calculate_pol(det_data, index)
        omega = np.ones((IMAGE_H, IMAGE_W)) * det_data["omega"][index]
        image_source.data.update(gamma=[gamma], nu=[nu], omega=[omega])

    def update_overview_plot():
        """Redraw the X/Y projection overviews and reset the frame/motor ranges.

        Each overview image is the mean of every frame along one detector axis.
        The secondary y-axis maps frames to scanning-motor positions. Previously
        a single-frame scan divided by zero (``n_im - 1``), and a motor that
        did not move produced a zero-width range. Both cases now fall back to a
        unit step.
        """
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x],
                                                 dw=[n_x],
                                                 dh=[n_im])
        overview_plot_y_image_source.data.update(image=[overview_y],
                                                 dw=[n_y],
                                                 dh=[n_im])

        if proj_auto_checkbox.active:
            # common display range for both overviews
            im_min = min(np.min(overview_x), np.min(overview_y))
            im_max = max(np.max(overview_x), np.max(overview_y))

            proj_display_min_spinner.value = im_min
            proj_display_max_spinner.value = im_max

            overview_plot_x_image_glyph.color_mapper.low = im_min
            overview_plot_y_image_glyph.color_mapper.low = im_min
            overview_plot_x_image_glyph.color_mapper.high = im_max
            overview_plot_y_image_glyph.color_mapper.high = im_max

        frame_range.start = 0
        frame_range.end = n_im
        frame_range.reset_start = 0
        frame_range.reset_end = n_im
        frame_range.bounds = (0, n_im)

        scan_motor = det_data["scan_motor"]
        overview_plot_y.axis[1].axis_label = f"Scanning motor, {scan_motor}"

        var = det_data[scan_motor]
        var_start = var[0]
        # Extend past the last position by one motor step so the last frame
        # gets the same width as the others.
        var_step = (var[-1] - var[0]) / (n_im - 1) if n_im > 1 else 0
        if var_step == 0:
            # Single frame or motor did not move: no step to extrapolate from.
            # Use an (arbitrary) unit width to keep the range non-degenerate.
            var_step = 1
        var_end = var[-1] + var_step

        scanning_motor_range.start = var_start
        scanning_motor_range.end = var_end
        scanning_motor_range.reset_start = var_start
        scanning_motor_range.reset_end = var_end
        # handle both, ascending and descending sequences
        scanning_motor_range.bounds = (min(var_start,
                                           var_end), max(var_start, var_end))

    def file_select_callback(_attr, old, new):
        """Load the selected detector file and refresh the image and overviews.

        Only single selections are honored. Selecting several files reverts
        the widget to the previous selection.
        """
        nonlocal det_data
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            file_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            # (the revert above re-enters this callback with the multi-selection
            # as ``old`` and the restored single file as ``new``)
            return

        det_data = pyzebra.read_detector_data(new[0])

        # A previously uploaded .cami file may provide the UB matrix
        if cami_meta and "crystal" in cami_meta:
            det_data["ub"] = cami_meta["crystal"]["UB"]

        # Limit frame navigation to the frames of the new file
        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        index_slider.end = det_data["data"].shape[0] - 1

        zebra_mode = det_data["zebra_mode"]
        if zebra_mode == "nb":
            metadata_table_source.data.update(geom=["normal beam"])
        else:  # zebra_mode == "bi"
            metadata_table_source.data.update(geom=["bisecting"])

        # Explicit redraw: presumably needed because setting the spinner to 0
        # fires no change event when it already was 0.
        update_image(0)
        update_overview_plot()

    file_select = MultiSelect(title="Available .hdf files:",
                              width=210,
                              height=250)
    file_select.on_change("value", file_select_callback)

    def index_callback(_attr, _old, new):
        """Redraw the main image for the frame index chosen in the spinner."""
        update_image(new)

    index_slider = Slider(value=0, start=0, end=1, show_value=False, width=400)

    index_spinner = Spinner(title="Image index:", value=0, low=0, width=100)
    index_spinner.on_change("value", index_callback)

    index_slider.js_link("value_throttled", index_spinner, "value")
    index_spinner.js_link("value", index_slider, "value")

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_PLOT_H,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            gamma=[np.zeros((1, 1))],
            nu=[np.zeros((1, 1))],
            omega=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    gamma_glyph = Image(image="gamma",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)
    nu_glyph = Image(image="nu",
                     x="x",
                     y="y",
                     dw="dw",
                     dh="dh",
                     global_alpha=0)
    omega_glyph = Image(image="omega",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)
    plot.add_glyph(image_source, gamma_glyph)
    plot.add_glyph(image_source, nu_glyph)
    plot.add_glyph(image_source, omega_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_PLOT_H,
        plot_width=150,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("gamma", "@gamma"),
        ("nu", "@nu"),
        ("omega", "@omega"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        """Recompute the ROI intensity-per-frame curve when the box changes.

        The box lives in a ``Rect`` source, whose ``x``/``y`` columns are the
        rectangle CENTER and ``width``/``height`` its full extents. The
        original code used the center as the lower-left corner, which shifted
        the ROI by half a box. Pixel bounds are clamped at 0 because negative
        slice bounds would wrap around. With no box, or no loaded file, the
        curve is cleared.
        """
        if new["x"] and "data" in det_data:
            h5_data = det_data["data"]
            x_center, y_center = new["x"][0], new["y"][0]
            half_width, half_height = new["width"][0] / 2, new["height"][0] / 2

            left = max(int(np.floor(x_center - half_width)), 0)
            right = max(int(np.ceil(x_center + half_width)), 0)
            bottom = max(int(np.floor(y_center - half_height)), 0)
            top = max(int(np.ceil(y_center + half_height)), 0)

            x_val = np.arange(h5_data.shape[0])
            # frames are (frame, y, x); sum the box region of every frame
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame ranges
    frame_range = Range1d(0, 1, bounds=(0, 1))
    scanning_motor_range = Range1d(0, 1, bounds=(0, 1))

    det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W))
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_W - 3,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_W],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H))
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_H + 22,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(
        LinearAxis(
            y_range_name="scanning_motor",
            axis_label="Scanning motor",
            major_label_orientation="vertical",
        ),
        place="right",
    )

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_H],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        """Apply the selected palette to the main image and both overviews.

        Only the palette of each existing color mapper is swapped, so the
        display range (``low``/``high``) is preserved. That range may come from
        the auto-range logic or the min/max spinners. Previously, brand-new
        mappers were created and that range was silently reset.
        """
        palette = cmap_dict[new]
        for glyph in (image_glyph, overview_plot_x_image_glyph,
                      overview_plot_y_image_glyph):
            glyph.color_mapper.palette = palette

    colormap = Select(title="Colormap:",
                      options=list(cmap_dict.keys()),
                      width=210)
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    STEP = 1

    def main_auto_checkbox_callback(state):
        """Toggle the main image's automatic display range.

        ``state`` is what ``CheckboxGroup.on_click`` passes (its active
        indices). When non-empty, the manual min/max spinners are disabled.
        The image is redrawn only once a detector file is loaded. Before, this
        raised KeyError because ``update_image`` reads ``det_data["data"]``.
        """
        manual_locked = bool(state)
        display_min_spinner.disabled = manual_locked
        display_max_spinner.disabled = manual_locked

        if "data" in det_data:
            update_image()

    main_auto_checkbox = CheckboxGroup(labels=["Main Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    main_auto_checkbox.on_click(main_auto_checkbox_callback)

    def display_max_spinner_callback(_attr, _old_value, new_value):
        """Use the new maximum for the main image and keep min strictly below it."""
        upper = new_value
        display_min_spinner.high = upper - STEP
        image_glyph.color_mapper.high = upper

    display_max_spinner = Spinner(
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    def display_min_spinner_callback(_attr, _old_value, new_value):
        """Use the new minimum for the main image and keep max strictly above it."""
        lower = new_value
        display_max_spinner.low = lower + STEP
        image_glyph.color_mapper.low = lower

    display_min_spinner = Spinner(
        low=0,
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    PROJ_STEP = 0.1

    def proj_auto_checkbox_callback(state):
        """Toggle the overview projections' automatic display range.

        ``state`` is what ``CheckboxGroup.on_click`` passes (its active
        indices). When non-empty, the manual min/max spinners are disabled.
        The overviews are redrawn only once a detector file is loaded. Before,
        this raised KeyError because ``update_overview_plot`` reads
        ``det_data["data"]``.
        """
        manual_locked = bool(state)
        proj_display_min_spinner.disabled = manual_locked
        proj_display_max_spinner.disabled = manual_locked

        if "data" in det_data:
            update_overview_plot()

    proj_auto_checkbox = CheckboxGroup(labels=["Projections Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    proj_auto_checkbox.on_click(proj_auto_checkbox_callback)

    def proj_display_max_spinner_callback(_attr, _old_value, new_value):
        """Use the new maximum for both overviews and keep min strictly below it."""
        proj_display_min_spinner.high = new_value - PROJ_STEP
        for overview_glyph in (overview_plot_x_image_glyph,
                               overview_plot_y_image_glyph):
            overview_glyph.color_mapper.high = new_value

    proj_display_max_spinner = Spinner(
        low=0 + PROJ_STEP,
        value=1,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_max_spinner.on_change("value",
                                       proj_display_max_spinner_callback)

    def proj_display_min_spinner_callback(_attr, _old_value, new_value):
        """Use the new minimum for both overviews and keep max strictly above it."""
        proj_display_max_spinner.low = new_value + PROJ_STEP
        for overview_glyph in (overview_plot_x_image_glyph,
                               overview_plot_y_image_glyph):
            overview_glyph.color_mapper.low = new_value

    proj_display_min_spinner = Spinner(
        low=0,
        high=1 - PROJ_STEP,
        value=0,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_min_spinner.on_change("value",
                                       proj_display_min_spinner_callback)

    def hkl_button_callback():
        """Compute per-pixel hkl for the current frame and expose them to the hover tool."""
        h_map, k_map, l_map = calculate_hkl(det_data, index_spinner.value)
        image_source.data.update(h=[h_map], k=[k_map], l=[l_map])

    hkl_button = Button(label="Calculate hkl (slow)", width=210)
    hkl_button.on_click(hkl_button_callback)

    def events_list_callback(_attr, _old, new):
        """Mirror the edited events text into the widget stored as ``doc.events_list_spind``.

        NOTE(review): that attribute is assumed to be set by another panel — confirm.
        """
        spind_events = doc.events_list_spind
        spind_events.value = new

    events_list = TextAreaInput(rows=7, width=830)
    events_list.on_change("value", events_list_callback)
    doc.events_list_hdf_viewer = events_list

    def add_event_button_callback():
        diff_vec = []
        p0 = [1.0, 0.0, 1.0]
        maxfev = 100000

        wave = det_data["wave"]
        ddist = det_data["ddist"]

        gamma = det_data["gamma"][0]
        omega = det_data["omega"][0]
        nu = det_data["nu"][0]
        chi = det_data["chi"][0]
        phi = det_data["phi"][0]

        scan_motor = det_data["scan_motor"]
        var_angle = det_data[scan_motor]

        x0 = int(np.floor(det_x_range.start))
        xN = int(np.ceil(det_x_range.end))
        y0 = int(np.floor(det_y_range.start))
        yN = int(np.ceil(det_y_range.end))
        fr0 = int(np.floor(frame_range.start))
        frN = int(np.ceil(frame_range.end))
        data_roi = det_data["data"][fr0:frN, y0:yN, x0:xN]

        cnts = np.sum(data_roi, axis=(1, 2))
        coeff, _ = curve_fit(gauss,
                             range(len(cnts)),
                             cnts,
                             p0=p0,
                             maxfev=maxfev)

        m = cnts.mean()
        sd = cnts.std()
        snr_cnts = np.where(sd == 0, 0, m / sd)

        frC = fr0 + coeff[1]
        var_F = var_angle[math.floor(frC)]
        var_C = var_angle[math.ceil(frC)]
        frStep = frC - math.floor(frC)
        var_step = var_C - var_F
        var_p = var_F + var_step * frStep

        if scan_motor == "gamma":
            gamma = var_p
        elif scan_motor == "omega":
            omega = var_p
        elif scan_motor == "nu":
            nu = var_p
        elif scan_motor == "chi":
            chi = var_p
        elif scan_motor == "phi":
            phi = var_p

        intensity = coeff[1] * abs(
            coeff[2] * var_step) * math.sqrt(2) * math.sqrt(np.pi)

        projX = np.sum(data_roi, axis=(0, 1))
        coeff, _ = curve_fit(gauss,
                             range(len(projX)),
                             projX,
                             p0=p0,
                             maxfev=maxfev)
        x_pos = x0 + coeff[1]

        projY = np.sum(data_roi, axis=(0, 2))
        coeff, _ = curve_fit(gauss,
                             range(len(projY)),
                             projY,
                             p0=p0,
                             maxfev=maxfev)
        y_pos = y0 + coeff[1]

        ga, nu = pyzebra.det2pol(ddist, gamma, nu, x_pos, y_pos)
        diff_vector = pyzebra.z1frmd(wave, ga, omega, chi, phi, nu)
        d_spacing = float(pyzebra.dandth(wave, diff_vector)[0])
        diff_vector = diff_vector.flatten() * 1e10
        dv1, dv2, dv3 = diff_vector

        diff_vec.append(diff_vector)

        if events_list.value and not events_list.value.endswith("\n"):
            events_list.value = events_list.value + "\n"

        events_list.value = (
            events_list.value +
            f"{x_pos} {y_pos} {intensity} {snr_cnts} {dv1} {dv2} {dv3} {d_spacing}"
        )

    add_event_button = Button(label="Add spind event")
    add_event_button.on_click(add_event_button_callback)

    metadata_table_source = ColumnDataSource(
        dict(geom=[""], temp=[None], mf=[None]))
    num_formatter = NumberFormatter(format="0.00", nan_format="")
    metadata_table = DataTable(
        source=metadata_table_source,
        columns=[
            TableColumn(field="geom", title="Geometry", width=100),
            TableColumn(field="temp",
                        title="Temperature",
                        formatter=num_formatter,
                        width=100),
            TableColumn(field="mf",
                        title="Magnetic Field",
                        formatter=num_formatter,
                        width=100),
        ],
        width=300,
        height=50,
        autosize_mode="none",
        index_position=None,
    )

    # Final layout
    import_layout = column(proposal_textinput, upload_div, upload_button,
                           file_select)
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False))
    colormap_layout = column(
        colormap,
        main_auto_checkbox,
        row(display_min_spinner, display_max_spinner),
        proj_auto_checkbox,
        row(proj_display_min_spinner, proj_display_max_spinner),
    )

    layout_controls = column(
        row(metadata_table, index_spinner,
            column(Spacer(height=25), index_slider)),
        row(add_event_button, hkl_button),
        row(events_list),
    )

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
            toolbar_location="left",
        ), )

    tab_layout = row(
        column(import_layout, colormap_layout),
        column(layout_overview, layout_controls),
        column(roi_avg_plot, layout_image),
    )

    return Panel(child=tab_layout, title="hdf viewer")
Пример #11
0
def create():
    """Build the "Data Viewer" tab.

    The tab lets the user open a ``.cami`` file, choose one of the detector
    data files it lists, browse individual frames with their projections,
    look at overview projections of the whole frame stack, sum counts inside
    a drawn box, compute hkl values for the current frame and collect ROI
    selections.

    Returns:
        Panel: a bokeh ``Panel`` titled "Data Viewer" containing the layout.
    """
    # Detector data of the currently selected file (replaced by
    # ``filelist_callback``) and the collected ROI selections.
    det_data = {}
    roi_selection = {}

    upload_div = Div(text="Open .cami file:")

    def upload_button_callback(_attr, _old, new):
        # ``new`` is the base64-encoded content of the uploaded .cami file.
        with io.StringIO(base64.b64decode(new).decode()) as file:
            h5meta_list = pyzebra.parse_h5meta(file)
            file_list = h5meta_list["filelist"]
            filelist.options = file_list
            # A .cami file listing no data files would otherwise raise
            # IndexError here.
            if file_list:
                # Changing the selection runs ``filelist_callback``.
                filelist.value = file_list[0]

    upload_button = FileInput(accept=".cami")
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        """Show frame ``index`` (default: the index spinner value)."""
        if "data" not in det_data:
            # No file loaded yet (e.g. Auto Range toggled before opening a
            # file); there is nothing to display.
            return

        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        # hkl values are reset to placeholders for a new frame; they are only
        # recomputed by the "Calculate hkl" button.
        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if auto_toggle.active:
            im_max = int(np.max(current_image))
            im_min = int(np.min(current_image))

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

    def update_overview_plot():
        """Redraw the X/Y overview projections of the whole frame stack."""
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        # Average out the y axis (axis 1) for the X overview and the x axis
        # (axis 2) for the Y overview.
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x], dw=[n_x])
        overview_plot_y_image_source.data.update(image=[overview_y], dw=[n_y])

        if frame_button_group.active == 0:  # Frame
            overview_plot_x.axis[1].axis_label = "Frame"
            overview_plot_y.axis[1].axis_label = "Frame"

            overview_plot_x_image_source.data.update(y=[0], dh=[n_im])
            overview_plot_y_image_source.data.update(y=[0], dh=[n_im])

        elif frame_button_group.active == 1:  # Omega
            overview_plot_x.axis[1].axis_label = "Omega"
            overview_plot_y.axis[1].axis_label = "Omega"

            om = det_data["rot_angle"]
            om_start = om[0]
            if n_im > 1:
                # n_im rows, each one angle step high (assuming evenly
                # spaced rotation angles).
                om_span = (om[-1] - om[0]) * n_im / (n_im - 1)
            else:
                # A single frame has no angle step; use a unit height to
                # avoid dividing by zero.
                om_span = 1
            overview_plot_x_image_source.data.update(y=[om_start],
                                                     dh=[om_span])
            overview_plot_y_image_source.data.update(y=[om_start],
                                                     dh=[om_span])

    def filelist_callback(_attr, _old, new):
        nonlocal det_data
        det_data = pyzebra.read_detector_data(new)

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        update_image(0)
        update_overview_plot()

    filelist = Select()
    filelist.on_change("value", filelist_callback)

    def index_spinner_callback(_attr, _old, new):
        update_image(new)

    index_spinner = Spinner(title="Image index:", value=0, low=0)
    index_spinner.on_change("value", index_spinner_callback)

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    # Invisible hkl images; they only exist so the hover tool can show
    # @h/@k/@l values.
    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=200,
        plot_width=IMAGE_W * 3,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_H * 3,
        plot_width=200,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        # Sum the counts inside the drawn box for every frame; clear the plot
        # when the box is removed or no data has been loaded yet.
        if new["x"] and "data" in det_data:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            left = int(np.floor(new["x"][0]))
            right = int(np.ceil(new["x"][0] + new["width"][0]))
            bottom = int(np.floor(new["y"][0]))
            top = int(np.ceil(new["y"][0] + new["height"][0]))
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame range
    frame_range = DataRange1d()
    det_x_range = DataRange1d()
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = DataRange1d()
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    def frame_button_group_callback(_active):
        update_overview_plot()

    frame_button_group = RadioButtonGroup(labels=["Frames", "Omega"], active=0)
    frame_button_group.on_click(frame_button_group_callback)

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        # Swap the palette but keep each mapper's current display range, so
        # a manually chosen min/max survives a colormap change.
        palette = cmap_dict[new]
        for glyph in (image_glyph, overview_plot_x_image_glyph,
                      overview_plot_y_image_glyph):
            old_mapper = glyph.color_mapper
            glyph.color_mapper = LinearColorMapper(
                palette=palette,
                low=getattr(old_mapper, "low", None),
                high=getattr(old_mapper, "high", None),
            )

    colormap = Select(title="Colormap:", options=list(cmap_dict.keys()))
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    radio_button_group = RadioButtonGroup(labels=["nb", "nb_bi"], active=0)

    STEP = 1

    # ---- colormap auto toggle button
    def auto_toggle_callback(state):
        # Manual display limits are editable only while auto range is off.
        display_min_spinner.disabled = bool(state)
        display_max_spinner.disabled = bool(state)

        update_image()

    auto_toggle = Toggle(label="Auto Range",
                         active=True,
                         button_type="default")
    auto_toggle.on_click(auto_toggle_callback)

    # ---- colormap display max value
    def display_max_spinner_callback(_attr, _old_value, new_value):
        display_min_spinner.high = new_value - STEP
        image_glyph.color_mapper.high = new_value

    display_max_spinner = Spinner(
        title="Maximal Display Value:",
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    # ---- colormap display min value
    def display_min_spinner_callback(_attr, _old_value, new_value):
        display_max_spinner.low = new_value + STEP
        image_glyph.color_mapper.low = new_value

    display_min_spinner = Spinner(
        title="Minimal Display Value:",
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    def hkl_button_callback():
        index = index_spinner.value
        setup_type = "nb_bi" if radio_button_group.active else "nb"
        h, k, l = calculate_hkl(det_data, index, setup_type)
        image_source.data.update(h=[h], k=[k], l=[l])

    hkl_button = Button(label="Calculate hkl (slow)")
    hkl_button.on_click(hkl_button_callback)

    selection_list = TextAreaInput(rows=7)

    def selection_button_callback():
        # The current view ranges of the overview plots define the ROI:
        # [x_start, x_end, y_start, y_end, frame_start, frame_end].
        selection = [
            int(np.floor(det_x_range.start)),
            int(np.ceil(det_x_range.end)),
            int(np.floor(det_y_range.start)),
            int(np.ceil(det_y_range.end)),
            int(np.floor(frame_range.start)),
            int(np.ceil(frame_range.end)),
        ]

        # NOTE(review): characters [-8:-4] of the path are used as a file id,
        # presumably the run number before a 4-character extension; confirm.
        filename_id = filelist.value[-8:-4]
        roi_selection.setdefault(filename_id, []).append(selection)

        selection_list.value = str(roi_selection)

    selection_button = Button(label="Add selection")
    selection_button.on_click(selection_button_callback)

    # Final layout
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False),
        row(index_spinner))
    colormap_layout = column(colormap, auto_toggle, display_max_spinner,
                             display_min_spinner)
    hkl_layout = column(radio_button_group, hkl_button)

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
        ),
        frame_button_group,
    )

    tab_layout = row(
        column(
            upload_div,
            upload_button,
            filelist,
            layout_image,
            row(colormap_layout, hkl_layout),
        ),
        column(
            roi_avg_plot,
            layout_overview,
            row(selection_button, selection_list),
        ),
    )

    return Panel(child=tab_layout, title="Data Viewer")
from bokeh.io import output_file, show
from bokeh.models import FileInput

# Standalone demo: render a single FileInput widget to a static HTML page.
file_input = FileInput()

# Route bokeh output to a standalone HTML file, then save and open it.
output_file("file_input.html")
show(file_input)
Пример #13
0
def create():
    config = pyzebra.AnatricConfig()

    def _load_config_file(file):
        config.load_from_file(file)

        logfile_textinput.value = config.logfile
        logfile_verbosity_select.value = config.logfile_verbosity

        filelist_type.value = config.filelist_type
        filelist_format_textinput.value = config.filelist_format
        filelist_datapath_textinput.value = config.filelist_datapath
        filelist_ranges_textareainput.value = "\n".join(
            map(str, config.filelist_ranges))

        crystal_sample_textinput.value = config.crystal_sample
        lambda_textinput.value = config.crystal_lambda
        zeroOM_textinput.value = config.crystal_zeroOM
        zeroSTT_textinput.value = config.crystal_zeroSTT
        zeroCHI_textinput.value = config.crystal_zeroCHI
        ub_textareainput.value = config.crystal_UB

        dataFactory_implementation_select.value = config.dataFactory_implementation
        dataFactory_dist1_textinput.value = config.dataFactory_dist1
        reflectionPrinter_format_select.value = config.reflectionPrinter_format

        set_active_widgets(config.algorithm)
        if config.algorithm == "adaptivemaxcog":
            threshold_textinput.value = config.threshold
            shell_textinput.value = config.shell
            steepness_textinput.value = config.steepness
            duplicateDistance_textinput.value = config.duplicateDistance
            maxequal_textinput.value = config.maxequal
            aps_window_textinput.value = str(
                tuple(map(int, config.aps_window.values())))

        elif config.algorithm == "adaptivedynamic":
            adm_window_textinput.value = str(
                tuple(map(int, config.adm_window.values())))
            border_textinput.value = str(
                tuple(map(int, config.border.values())))
            minWindow_textinput.value = str(
                tuple(map(int, config.minWindow.values())))
            reflectionFile_textinput.value = config.reflectionFile
            targetMonitor_textinput.value = config.targetMonitor
            smoothSize_textinput.value = config.smoothSize
            loop_textinput.value = config.loop
            minPeakCount_textinput.value = config.minPeakCount
            displacementCurve_textinput.value = "\n".join(
                map(str, config.displacementCurve))
        else:
            raise ValueError("Unknown processing mode.")

    def set_active_widgets(implementation):
        if implementation == "adaptivemaxcog":
            mode_radio_button_group.active = 0
            disable_adaptivemaxcog = False
            disable_adaptivedynamic = True

        elif implementation == "adaptivedynamic":
            mode_radio_button_group.active = 1
            disable_adaptivemaxcog = True
            disable_adaptivedynamic = False
        else:
            raise ValueError(
                "Implementation can be either 'adaptivemaxcog' or 'adaptivedynamic'"
            )

        threshold_textinput.disabled = disable_adaptivemaxcog
        shell_textinput.disabled = disable_adaptivemaxcog
        steepness_textinput.disabled = disable_adaptivemaxcog
        duplicateDistance_textinput.disabled = disable_adaptivemaxcog
        maxequal_textinput.disabled = disable_adaptivemaxcog
        aps_window_textinput.disabled = disable_adaptivemaxcog

        adm_window_textinput.disabled = disable_adaptivedynamic
        border_textinput.disabled = disable_adaptivedynamic
        minWindow_textinput.disabled = disable_adaptivedynamic
        reflectionFile_textinput.disabled = disable_adaptivedynamic
        targetMonitor_textinput.disabled = disable_adaptivedynamic
        smoothSize_textinput.disabled = disable_adaptivedynamic
        loop_textinput.disabled = disable_adaptivedynamic
        minPeakCount_textinput.disabled = disable_adaptivedynamic
        displacementCurve_textinput.disabled = disable_adaptivedynamic

    upload_div = Div(text="Open XML configuration file:")

    def upload_button_callback(_attr, _old, new):
        with io.BytesIO(base64.b64decode(new)) as file:
            _load_config_file(file)

    upload_button = FileInput(accept=".xml")
    upload_button.on_change("value", upload_button_callback)

    # General parameters
    # ---- logfile
    def logfile_textinput_callback(_attr, _old, new):
        config.logfile = new

    logfile_textinput = TextInput(title="Logfile:",
                                  value="logfile.log",
                                  width=520)
    logfile_textinput.on_change("value", logfile_textinput_callback)

    def logfile_verbosity_select_callback(_attr, _old, new):
        config.logfile_verbosity = new

    logfile_verbosity_select = Select(title="verbosity:",
                                      options=["0", "5", "10", "15", "30"],
                                      width=70)
    logfile_verbosity_select.on_change("value",
                                       logfile_verbosity_select_callback)

    # ---- FileList
    def filelist_type_callback(_attr, _old, new):
        config.filelist_type = new

    filelist_type = Select(title="File List:",
                           options=["TRICS", "SINQ"],
                           width=100)
    filelist_type.on_change("value", filelist_type_callback)

    def filelist_format_textinput_callback(_attr, _old, new):
        config.filelist_format = new

    filelist_format_textinput = TextInput(title="format:", width=490)
    filelist_format_textinput.on_change("value",
                                        filelist_format_textinput_callback)

    def filelist_datapath_textinput_callback(_attr, _old, new):
        config.filelist_datapath = new

    filelist_datapath_textinput = TextInput(title="datapath:")
    filelist_datapath_textinput.on_change(
        "value", filelist_datapath_textinput_callback)

    def filelist_ranges_textareainput_callback(_attr, _old, new):
        ranges = []
        for line in new.splitlines():
            ranges.append(re.findall(r"\b\d+\b", line))
        config.filelist_ranges = ranges

    filelist_ranges_textareainput = TextAreaInput(title="ranges:", height=100)
    filelist_ranges_textareainput.on_change(
        "value", filelist_ranges_textareainput_callback)

    # ---- crystal
    def crystal_sample_textinput_callback(_attr, _old, new):
        config.crystal_sample = new

    crystal_sample_textinput = TextInput(title="Sample Name:")
    crystal_sample_textinput.on_change("value",
                                       crystal_sample_textinput_callback)

    def lambda_textinput_callback(_attr, _old, new):
        config.crystal_lambda = new

    lambda_textinput = TextInput(title="lambda:", width=140)
    lambda_textinput.on_change("value", lambda_textinput_callback)

    def ub_textareainput_callback(_attr, _old, new):
        config.crystal_UB = new

    ub_textareainput = TextAreaInput(title="UB matrix:", height=100)
    ub_textareainput.on_change("value", ub_textareainput_callback)

    def zeroOM_textinput_callback(_attr, _old, new):
        config.crystal_zeroOM = new

    zeroOM_textinput = TextInput(title="zeroOM:", width=140)
    zeroOM_textinput.on_change("value", zeroOM_textinput_callback)

    def zeroSTT_textinput_callback(_attr, _old, new):
        config.crystal_zeroSTT = new

    zeroSTT_textinput = TextInput(title="zeroSTT:", width=140)
    zeroSTT_textinput.on_change("value", zeroSTT_textinput_callback)

    def zeroCHI_textinput_callback(_attr, _old, new):
        config.crystal_zeroCHI = new

    zeroCHI_textinput = TextInput(title="zeroCHI:", width=140)
    zeroCHI_textinput.on_change("value", zeroCHI_textinput_callback)

    # ---- DataFactory
    def dataFactory_implementation_select_callback(_attr, _old, new):
        config.dataFactory_implementation = new

    dataFactory_implementation_select = Select(
        title="DataFactory implementation:",
        options=DATA_FACTORY_IMPLEMENTATION,
        width=300,
    )
    dataFactory_implementation_select.on_change(
        "value", dataFactory_implementation_select_callback)

    def dataFactory_dist1_textinput_callback(_attr, _old, new):
        config.dataFactory_dist1 = new

    dataFactory_dist1_textinput = TextInput(title="dist1:", width=290)
    dataFactory_dist1_textinput.on_change(
        "value", dataFactory_dist1_textinput_callback)

    # ---- BackgroundProcessor

    # ---- DetectorEfficency

    # ---- ReflectionPrinter
    def reflectionPrinter_format_select_callback(_attr, _old, new):
        config.reflectionPrinter_format = new

    reflectionPrinter_format_select = Select(
        title="ReflectionPrinter format:",
        options=REFLECTION_PRINTER_FORMATS,
        width=300,
    )
    reflectionPrinter_format_select.on_change(
        "value", reflectionPrinter_format_select_callback)

    # Adaptive Peak Detection (adaptivemaxcog)
    # ---- threshold
    def threshold_textinput_callback(_attr, _old, new):
        config.threshold = new

    threshold_textinput = TextInput(title="Threshold:")
    threshold_textinput.on_change("value", threshold_textinput_callback)

    # ---- shell
    def shell_textinput_callback(_attr, _old, new):
        config.shell = new

    shell_textinput = TextInput(title="Shell:")
    shell_textinput.on_change("value", shell_textinput_callback)

    # ---- steepness
    def steepness_textinput_callback(_attr, _old, new):
        config.steepness = new

    steepness_textinput = TextInput(title="Steepness:")
    steepness_textinput.on_change("value", steepness_textinput_callback)

    # ---- duplicateDistance
    def duplicateDistance_textinput_callback(_attr, _old, new):
        config.duplicateDistance = new

    duplicateDistance_textinput = TextInput(title="Duplicate Distance:")
    duplicateDistance_textinput.on_change(
        "value", duplicateDistance_textinput_callback)

    # ---- maxequal
    def maxequal_textinput_callback(_attr, _old, new):
        config.maxequal = new

    maxequal_textinput = TextInput(title="Max Equal:")
    maxequal_textinput.on_change("value", maxequal_textinput_callback)

    # ---- window
    def aps_window_textinput_callback(_attr, _old, new):
        config.aps_window = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    aps_window_textinput = TextInput(title="Window (x, y, z):")
    aps_window_textinput.on_change("value", aps_window_textinput_callback)

    # Adaptive Dynamic Mask Integration (adaptivedynamic)
    # ---- window
    def adm_window_textinput_callback(_attr, _old, new):
        config.adm_window = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    adm_window_textinput = TextInput(title="Window (x, y, z):")
    adm_window_textinput.on_change("value", adm_window_textinput_callback)

    # ---- border
    def border_textinput_callback(_attr, _old, new):
        config.border = dict(zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    border_textinput = TextInput(title="Border (x, y, z):")
    border_textinput.on_change("value", border_textinput_callback)

    # ---- minWindow
    def minWindow_textinput_callback(_attr, _old, new):
        config.minWindow = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    minWindow_textinput = TextInput(title="Min Window (x, y, z):")
    minWindow_textinput.on_change("value", minWindow_textinput_callback)

    # ---- reflectionFile
    def reflectionFile_textinput_callback(_attr, _old, new):
        config.reflectionFile = new

    reflectionFile_textinput = TextInput(title="Reflection File:")
    reflectionFile_textinput.on_change("value",
                                       reflectionFile_textinput_callback)

    # ---- targetMonitor
    def targetMonitor_textinput_callback(_attr, _old, new):
        config.targetMonitor = new

    targetMonitor_textinput = TextInput(title="Target Monitor:")
    targetMonitor_textinput.on_change("value",
                                      targetMonitor_textinput_callback)

    # ---- smoothSize
    def smoothSize_textinput_callback(_attr, _old, new):
        config.smoothSize = new

    smoothSize_textinput = TextInput(title="Smooth Size:")
    smoothSize_textinput.on_change("value", smoothSize_textinput_callback)

    # ---- loop
    def loop_textinput_callback(_attr, _old, new):
        config.loop = new

    loop_textinput = TextInput(title="Loop:")
    loop_textinput.on_change("value", loop_textinput_callback)

    # ---- minPeakCount
    def minPeakCount_textinput_callback(_attr, _old, new):
        config.minPeakCount = new

    minPeakCount_textinput = TextInput(title="Min Peak Count:")
    minPeakCount_textinput.on_change("value", minPeakCount_textinput_callback)

    # ---- displacementCurve
    def displacementCurve_textinput_callback(_attr, _old, new):
        maps = []
        for line in new.splitlines():
            maps.append(re.findall(r"\d+(?:\.\d+)?", line))
        config.displacementCurve = maps

    displacementCurve_textinput = TextAreaInput(
        title="Displacement Curve (twotheta, x, y):", height=100)
    displacementCurve_textinput.on_change(
        "value", displacementCurve_textinput_callback)

    def mode_radio_button_group_callback(active):
        if active == 0:
            config.algorithm = "adaptivemaxcog"
            set_active_widgets("adaptivemaxcog")
        else:
            config.algorithm = "adaptivedynamic"
            set_active_widgets("adaptivedynamic")

    mode_radio_button_group = RadioButtonGroup(
        labels=["Adaptive Peak Detection", "Adaptive Dynamic Integration"],
        active=0)
    mode_radio_button_group.on_click(mode_radio_button_group_callback)
    set_active_widgets("adaptivemaxcog")

    def process_button_callback():
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp.xml"
            config.save_as(temp_file)
            pyzebra.anatric(temp_file)

            with open(config.logfile) as f_log:
                output_log.value = f_log.read()

    process_button = Button(label="Process", button_type="primary")
    process_button.on_click(process_button_callback)

    output_log = TextAreaInput(title="Logfile output:",
                               height=700,
                               disabled=True)
    output_config = TextAreaInput(title="Current config:",
                                  height=700,
                                  width=400,
                                  disabled=True)

    tab_layout = row(
        column(
            upload_div,
            upload_button,
            row(logfile_textinput, logfile_verbosity_select),
            row(filelist_type, filelist_format_textinput),
            filelist_datapath_textinput,
            filelist_ranges_textareainput,
            crystal_sample_textinput,
            row(lambda_textinput, zeroOM_textinput, zeroSTT_textinput,
                zeroCHI_textinput),
            ub_textareainput,
            row(dataFactory_implementation_select,
                dataFactory_dist1_textinput),
            reflectionPrinter_format_select,
            process_button,
        ),
        column(
            mode_radio_button_group,
            row(
                column(
                    threshold_textinput,
                    shell_textinput,
                    steepness_textinput,
                    duplicateDistance_textinput,
                    maxequal_textinput,
                    aps_window_textinput,
                ),
                column(
                    adm_window_textinput,
                    border_textinput,
                    minWindow_textinput,
                    reflectionFile_textinput,
                    targetMonitor_textinput,
                    smoothSize_textinput,
                    loop_textinput,
                    minPeakCount_textinput,
                    displacementCurve_textinput,
                ),
            ),
        ),
        output_config,
        output_log,
    )

    async def update_config():
        config.save_as("debug.xml")
        with open("debug.xml") as f_config:
            output_config.value = f_config.read()

    curdoc().add_periodic_callback(update_config, 1000)

    return Panel(child=tab_layout, title="Anatric")
Пример #14
0
def create():
    """Build the "ccl integrate" tab for fitting and exporting 1D scans.

    The tab lets the user load .ccl/.dat files (from a proposal folder or by
    upload), browse/merge/restore scans, fit them with a user-configurable
    set of model functions, compute intensities and preview/download the
    export files.

    Returns:
        Panel: bokeh panel titled "ccl integrate" holding the whole layout.
    """
    # Currently loaded scans (a list of scan dicts); replaced/extended by
    # `_load_datasets`.
    det_data = []
    # Fit-function tag (e.g. "gaussian-1") -> parameter table columns.
    fit_params = {}
    # Shared with the download button's CustomJS: exported file contents,
    # base file name and extensions (presumably one entry per export file).
    js_data = ColumnDataSource(
        data=dict(content=["", ""], fname=["", ""], ext=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        """List the .ccl/.dat files of the entered proposal in `file_select`.

        Raises:
            ValueError: if no proposal folder with that name exists.
        """
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_select.options = [
            (os.path.join(proposal_path, file), file)
            for file in os.listdir(proposal_path)
            if file.endswith((".ccl", ".dat"))
        ]
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        """Refill the scan table and the merge selector from `det_data`.

        Expects `det_data` to be non-empty: the first scan gets selected.
        """
        scan_list = [s["idx"] for s in det_data]
        hkl = [f'{s["h"]} {s["k"]} {s["l"]}' for s in det_data]
        export = [s.get("active", True) for s in det_data]
        scan_table_source.data.update(
            scan=scan_list,
            hkl=hkl,
            fit=[0] * len(scan_list),
            export=export,
        )
        # reset first, so that switching to [0] is a real change and
        # re-triggers scan_table_select_callback even if 0 was selected
        scan_table_source.selected.indices = []
        scan_table_source.selected.indices = [0]

        merge_options = [(str(i), f"{i} ({idx})")
                         for i, idx in enumerate(scan_list)]
        merge_from_select.options = merge_options
        merge_from_select.value = merge_options[0][0]

    file_select = MultiSelect(title="Available .ccl/.dat files:",
                              width=210,
                              height=250)

    def _local_files(f_names):
        """Yield (open file, file name) pairs for local file paths."""
        for f_name in f_names:
            with open(f_name) as file:
                yield file, f_name

    def _uploaded_files(contents, f_names):
        """Yield (text stream, file name) pairs for base64-encoded uploads."""
        for f_str, f_name in zip(contents, f_names):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                yield file, f_name

    def _load_datasets(files, append):
        """Parse `files` and either replace or extend `det_data` with them.

        Args:
            files: iterable of (file-like object, file name) pairs; the file
                extension selects the parser.
            append: merge into the current `det_data` if True, otherwise
                start from scratch. Appending while nothing is loaded yet
                behaves like opening.
        """
        nonlocal det_data
        new_data = det_data if append else []
        for file, f_name in files:
            base, ext = os.path.splitext(f_name)
            dataset = pyzebra.parse_1D(file, ext)
            pyzebra.normalize_dataset(dataset, monitor_spinner.value)
            if new_data:
                pyzebra.merge_datasets(new_data, dataset)
            else:
                # the first file defines the dataset and the export file name
                new_data = dataset
                pyzebra.merge_duplicates(new_data)
                js_data.data.update(fname=[base, base])

        if not new_data:
            # e.g. "Open New" without a selected file: keep the current data
            # instead of leaving the table pointing at an empty dataset
            print("WARNING: No scans were loaded")
            return

        det_data = new_data
        _init_datatable()
        append_upload_button.disabled = False

    def file_open_button_callback():
        _load_datasets(_local_files(file_select.value), append=False)

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        _load_datasets(_local_files(file_select.value), append=True)

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        files = _uploaded_files(new, upload_button.filename)
        _load_datasets(files, append=False)

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        files = _uploaded_files(new, append_upload_button.filename)
        _load_datasets(files, append=True)

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    append_upload_button = FileInput(accept=".ccl,.dat",
                                     multiple=True,
                                     width=200,
                                     disabled=True)
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, _old, new):
        """Renormalize all loaded scans to the new monitor value."""
        if det_data:
            pyzebra.normalize_dataset(det_data, new)
            _update_plot(_get_selected_scan())

    monitor_spinner = Spinner(title="Monitor:",
                              mode="int",
                              value=100_000,
                              low=1,
                              width=145)
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def _update_table():
        """Refresh the "Fit" column: 1 for scans that carry a fit result."""
        fit_ok = [(1 if "fit" in scan else 0) for scan in det_data]
        scan_table_source.data.update(fit=fit_ok)

    def _update_plot(scan):
        """Show `scan` (data, error bars and, if present, its fit) in the plot."""
        scan_motor = scan["scan_motor"]

        y = scan["counts"]
        x = scan[scan_motor]
        # counting statistics: error bars of +/- sqrt(counts)
        y_err = np.sqrt(y)

        plot.axis[0].axis_label = scan_motor
        plot_scatter_source.data.update(x=x,
                                        y=y,
                                        y_upper=y + y_err,
                                        y_lower=y - y_err)

        fit = scan.get("fit")
        if fit is not None:
            x_fit = np.linspace(x[0], x[-1], 100)
            plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

            x_bkg = []
            y_bkg = []
            xs_peak = []
            ys_peak = []
            comps = fit.eval_components(x=x_fit)
            # component prefixes follow the order of the fit functions
            for i, model in enumerate(fit_params):
                if "linear" in model:
                    x_bkg = x_fit
                    y_bkg = comps[f"f{i}_"]

                elif any(val in model
                         for val in ("gaussian", "voigt", "pvoigt")):
                    xs_peak.append(x_fit)
                    ys_peak.append(comps[f"f{i}_"])

            plot_bkg_source.data.update(x=x_bkg, y=y_bkg)
            plot_peak_source.data.update(xs=xs_peak, ys=ys_peak)

            fit_output_textinput.value = fit.fit_report()

        else:
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""

    # Main plot
    plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(only_visible=True),
        plot_height=470,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue"))
    plot.add_layout(
        Whisker(source=plot_scatter_source,
                base="x",
                upper="y_upper",
                lower="y_lower"))

    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"))

    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"))

    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)

    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_to_span)

    plot.add_layout(
        Legend(
            items=[
                ("data", [plot_scatter]),
                ("best fit", [plot_fit]),
                ("peak", [plot_peak]),
                ("linear", [plot_bkg]),
            ],
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        """Plot the newly selected scan, enforcing a single selection."""
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            scan_table_source.selected.indices = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        _update_plot(det_data[new[0]])

    def scan_table_source_callback(_attr, _old, _new):
        # table edits (e.g. the "Export" checkboxes) change the export preview
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(scan=[], hkl=[], fit=[], export=[]))
    scan_table_source.on_change("data", scan_table_source_callback)

    scan_table = DataTable(
        source=scan_table_source,
        columns=[
            TableColumn(field="scan", title="Scan", width=50),
            TableColumn(field="hkl", title="hkl", width=100),
            TableColumn(field="fit", title="Fit", width=50),
            TableColumn(field="export",
                        title="Export",
                        editor=CheckboxEditor(),
                        width=50),
        ],
        width=310,  # +60 because of the index column
        height=350,
        autosize_mode="none",
        editable=True,
    )

    scan_table_source.selected.on_change("indices", scan_table_select_callback)

    def _get_selected_scan():
        """Return the scan currently selected in the scan table."""
        return det_data[scan_table_source.selected.indices[0]]

    merge_from_select = Select(title="scan:", width=145)

    def merge_button_callback():
        """Merge the scan chosen in `merge_from_select` into the selected one."""
        scan_into = _get_selected_scan()
        scan_from = det_data[int(merge_from_select.value)]

        if scan_into is scan_from:
            print("WARNING: Selected scans for merging are identical")
            return

        pyzebra.merge_scans(scan_into, scan_from)
        _update_plot(_get_selected_scan())

    merge_button = Button(label="Merge into current", width=145)
    merge_button.on_click(merge_button_callback)

    def restore_button_callback():
        pyzebra.restore_scan(_get_selected_scan())
        _update_plot(_get_selected_scan())

    restore_button = Button(label="Restore scan", width=145)
    restore_button.on_click(restore_button_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        fit_from_span.location = new

    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    def fit_to_spinner_callback(_attr, _old, new):
        fit_to_span.location = new

    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        """Add a fit function of type `click.item` with a unique tag."""
        # bokeh requires (str, str) for MultiSelect options
        new_tag = f"{click.item}-{fitparams_select.tags[0]}"
        fitparams_select.options.append((new_tag, click.item))
        fit_params[new_tag] = fitparams_factory(click.item)
        # tags[0] is used as a running counter to keep tags unique
        fitparams_select.tags[0] += 1

    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=[
            ("Linear", "linear"),
            ("Gaussian", "gaussian"),
            ("Voigt", "voigt"),
            ("Pseudo Voigt", "pvoigt"),
            # ("Pseudo Voigt1", "pseudovoigt1"),
        ],
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        """Show the parameters of the selected fit function in the table."""
        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            fitparams_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        if new:
            fitparams_table_source.data.update(fit_params[new[0]])
        else:
            fitparams_table_source.data.update(
                dict(param=[], value=[], vary=[], min=[], max=[]))

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        """Remove the selected fit function."""
        if fitparams_select.value:
            sel_tag = fitparams_select.value[0]
            del fit_params[sel_tag]
            for elem in fitparams_select.options:
                if elem[0] == sel_tag:
                    fitparams_select.options.remove(elem)
                    break

            fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        """Return the initial parameter table columns for a fit function.

        Raises:
            ValueError: for an unknown fit function name.
        """
        if function == "linear":
            params = ["slope", "intercept"]
        elif function == "gaussian":
            params = ["amplitude", "center", "sigma"]
        elif function == "voigt":
            params = ["amplitude", "center", "sigma", "gamma"]
        elif function == "pvoigt":
            params = ["amplitude", "center", "sigma", "fraction"]
        elif function == "pseudovoigt1":
            params = ["amplitude", "center", "g_sigma", "l_sigma", "fraction"]
        else:
            raise ValueError("Unknown fit function")

        n = len(params)
        fitparams = dict(
            param=params,
            value=[None] * n,
            vary=[True] * n,
            min=[None] * n,
            max=[None] * n,
        )

        if function == "linear":
            fitparams["value"] = [0, 1]
            fitparams["vary"] = [False, True]
            fitparams["min"] = [None, 0]

        elif function == "gaussian":
            fitparams["min"] = [0, None, None]

        return fitparams

    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=[
            TableColumn(field="param", title="Parameter"),
            TableColumn(field="value", title="Value", editor=NumberEditor()),
            TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
            TableColumn(field="min", title="Min", editor=NumberEditor()),
            TableColumn(field="max", title="Max", editor=NumberEditor()),
        ],
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `linear` (background) and `gaussian` fit functions added
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="linear"))
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="gaussian"))
    fitparams_select.value = ["gaussian-1"]  # add selection to gauss

    fit_output_textinput = TextAreaInput(title="Fit results:",
                                         width=750,
                                         height=200)

    def _fit_scan(scan):
        """Fit `scan` with the current fit functions and compute its area."""
        pyzebra.fit_scan(scan,
                         fit_params,
                         fit_from=fit_from_spinner.value,
                         fit_to=fit_to_spinner.value)
        pyzebra.get_area(
            scan,
            area_method=AREA_METHODS[area_method_radiobutton.active],
            lorentz=lorentz_checkbox.active,
        )

    def proc_all_button_callback():
        """Fit every scan that is marked for export."""
        for scan, export in zip(det_data, scan_table_source.data["export"]):
            if export:
                _fit_scan(scan)

        _update_plot(_get_selected_scan())
        _update_table()

    proc_all_button = Button(label="Process All",
                             button_type="primary",
                             width=145)
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        """Fit only the currently selected scan."""
        scan = _get_selected_scan()
        _fit_scan(scan)

        _update_plot(scan)
        _update_table()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(labels=["Function", "Area"],
                                         active=0,
                                         width=145)

    lorentz_checkbox = CheckboxGroup(labels=["Lorentz Correction"],
                                     width=145,
                                     margin=(13, 5, 5, 5))

    export_preview_textinput = TextAreaInput(title="Export file preview:",
                                             width=500,
                                             height=400)

    def _update_preview():
        """Export the selected scans to a temp dir and show the result.

        The file contents are also pushed to `js_data` for the download
        button.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = os.path.join(temp_dir, "temp")
            export_data = [
                s for s, export in zip(det_data,
                                       scan_table_source.data["export"])
                if export
            ]

            pyzebra.export_1D(
                export_data,
                temp_file,
                export_target_select.value,
                hkl_precision=int(hkl_precision_select.value),
            )

            exported_content = ""
            file_content = []
            for ext in EXPORT_TARGETS[export_target_select.value]:
                fname = temp_file + ext
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                        exported_content += f"{ext} file:\n" + content
                else:
                    content = ""
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = exported_content

    def export_target_select_callback(_attr, _old, new):
        js_data.data.update(ext=EXPORT_TARGETS[new])
        _update_preview()

    export_target_select = Select(title="Export target:",
                                  options=list(EXPORT_TARGETS.keys()),
                                  value="fullprof",
                                  width=80)
    export_target_select.on_change("value", export_target_select_callback)
    js_data.data.update(ext=EXPORT_TARGETS[export_target_select.value])

    def hkl_precision_select_callback(_attr, _old, _new):
        _update_preview()

    hkl_precision_select = Select(title="hkl precision:",
                                  options=["2", "3", "4"],
                                  value="2",
                                  width=80)
    hkl_precision_select.on_change("value", hkl_precision_select_callback)

    save_button = Button(label="Download File(s)",
                         button_type="success",
                         width=200)
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))

    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(
        scan_table,
        row(monitor_spinner, column(Spacer(height=19), restore_button)),
        row(column(Spacer(height=19), merge_button), merge_from_select),
    )

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(
        export_preview_textinput,
        row(export_target_select, hkl_precision_select,
            column(Spacer(height=19), row(save_button))),
    )

    tab_layout = column(
        row(import_layout, scan_layout, plot, Spacer(width=30), export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="ccl integrate")
Пример #15
0
def production_tab(beam, acc):
    '''Simulation

    '''
    pre = PreText(text='''Track beam in the accelerator.''')

    track_button = Button(label='Track', button_type="success", width=315)

    refresh_button = Button(label='Refresh', button_type="warning", width=315)

    mode_button = RadioButtonGroup(labels=['Phase', 'Line', 'Field'], active=0)

    phase_button = RadioButtonGroup(
        labels=['z-x', 'z-y', 'x-y', 'x-px', 'x-py', 'y-px', 'y-py', 'px-py'],
        active=0)
    field_button = RadioButtonGroup(
        labels=['z-Ez', 'z-Ex', 'z-Ey', 'z-Bz', 'z-Bx', 'z-By'], active=0)
    line_button = RadioButtonGroup(labels=[
        'rms x', 'rms y', 'rms x Emitt', 'rms y Emitt', 'avg Energy',
        'x β-func', 'y β-func', 'avg x', 'avg y'
    ],
                                   active=0)

    file_input = FileInput(accept='.csv')

    phase_source = ColumnDataSource(data={
        'x': beam.df['z'],
        'y': beam.df['x']
    })
    phase_plot = figure(x_axis_label='z [m]',
                        y_axis_label='x [m]',
                        width=650,
                        height=250)
    phase_plot.scatter('x',
                       'y',
                       source=phase_source,
                       size=0.5,
                       color="#3A5785",
                       alpha=0.5)

    beam.df['x2'] = beam.df['x'] * beam.df['x']
    beam.df['rms x'] = beam.df['x2'].mean()**0.5
    line_source = ColumnDataSource(data={
        'x': beam.df['z'],
        'y': beam.df['rms x']
    })
    line_plot = figure(x_axis_label='z [m]',
                       y_axis_label='rms x [m]',
                       width=650,
                       height=250)
    line_plot.line('x',
                   'y',
                   source=line_source,
                   line_color='#3A5785',
                   line_alpha=0.7,
                   line_width=1)

    field_source = ColumnDataSource(data={
        'x': beam.df['z'],
        'y': beam.df['rms x']
    })
    field_plot = figure(x_axis_label='z [m]',
                        y_axis_label='Ez [MV/m]',
                        width=650,
                        height=250)
    field_plot.line('x',
                    'y',
                    source=field_source,
                    line_color='#3A5785',
                    line_alpha=0.7,
                    line_width=1)

    buttons = row(track_button, refresh_button)
    controls = row(pre, file_input)
    phase_tab = Panel(child=column(phase_button, phase_plot), title='Phase')
    line_tab = Panel(child=column(line_button, line_plot), title='Line')
    field_tab = Panel(child=column(field_button, field_plot), title='Field')
    tabs = Tabs(tabs=[phase_tab, line_tab, field_tab])
    tab = Panel(child=column(controls, tabs, buttons), title='Production')

    def calculate(df):
        df = df.sort_values('z')
        df['avg p'] = window_mean((df['px'] * df['px'] + df['py'] * df['py'] +
                                   df['pz'] * df['pz'])**0.5)
        df['xp'] = df['px'] / df['pz']
        df['yp'] = df['py'] / df['pz']
        df['x xp'] = window_mean(df['x'] * df['xp'])
        df['y yp'] = window_mean(df['y'] * df['yp'])
        df['rms x'] = window_rms(df['x'])
        df['rms y'] = window_rms(df['y'])
        df['rms xp'] = window_rms(df['xp'])
        df['rms yp'] = window_rms(df['yp'])
        df['rms x Emittance'] = (df['rms x']**2 * df['rms xp']**2 -
                                 df['x xp']**2)**0.5
        df['rms y Emittance'] = (df['rms y']**2 * df['rms yp']**2 -
                                 df['y yp']**2)**0.5
        df['avg Energy'] = df['avg p'] - beam.type.mass * rp.c**2 / rp.e / 1e6
        df['x β-function'] = df['rms x']**2 / df['rms x Emittance']
        df['y β-function'] = df['rms y']**2 / df['rms y Emittance']
        df['centroid x'] = window_mean(df['x'])
        df['centroid y'] = window_mean(df['y'])
        df['avg Ez'] = window_mean(df['Ez'])
        df['avg Ex'] = window_mean(df['Ex'])
        df['avg Ey'] = window_mean(df['Ey'])
        df['avg Bz'] = window_mean(df['Bz'])
        df['avg Bx'] = window_mean(df['Bx'])
        df['avg By'] = window_mean(df['By'])
        return df

    def track_handler(new, beam=beam, acc=acc):
        simulation = rp.Simulation(beam, acc)
        simulation.track(n_files=30, path=dirname(__file__) + '/data/')

    def refresh_handler(beam=beam, acc=acc):
        fname = dirname(__file__) + '/data/' + file_input.filename
        df = pd.read_csv(fname, dtype='float32')
        df = calculate(df)
        df = df[df.z >= acc.z_start]
        df = df[df.z <= acc.z_stop]
        if tabs.active == 0:
            if phase_button.active == 0:
                phase_plot.xaxis.axis_label = 'z [m]'
                phase_plot.yaxis.axis_label = 'x [m]'
                phase_source.data = {'x': df['z'], 'y': df['x']}
            if phase_button.active == 1:
                phase_plot.xaxis.axis_label = 'z [m]'
                phase_plot.yaxis.axis_label = 'y [m]'
                phase_source.data = {'x': df['z'], 'y': df['y']}
            if phase_button.active == 2:
                phase_plot.xaxis.axis_label = 'x [m]'
                phase_plot.yaxis.axis_label = 'y [m]'
                phase_source.data = {'x': df['x'], 'y': df['y']}
            if phase_button.active == 3:
                phase_plot.xaxis.axis_label = 'x [m]'
                phase_plot.yaxis.axis_label = 'px [MeV/c]'
                phase_source.data = {'x': df['x'], 'y': df['px']}
            if phase_button.active == 4:
                phase_plot.xaxis.axis_label = 'x [m]'
                phase_plot.yaxis.axis_label = 'py [MeV/c]'
                phase_source.data = {'x': df['x'], 'y': df['py']}
            if phase_button.active == 5:
                phase_plot.xaxis.axis_label = 'y [m]'
                phase_plot.yaxis.axis_label = 'px [MeV/c]'
                phase_source.data = {'x': df['y'], 'y': df['px']}
            if phase_button.active == 6:
                phase_plot.xaxis.axis_label = 'y [m]'
                phase_plot.yaxis.axis_label = 'py [MeV/c]'
                phase_source.data = {'x': df['y'], 'y': df['py']}
            if phase_button.active == 7:
                phase_plot.xaxis.axis_label = 'px [MeV/c]'
                phase_plot.yaxis.axis_label = 'py [MeV/c]'
                phase_source.data = {'x': df['px'], 'y': df['py']}
        if tabs.active == 1:
            if line_button.active == 0:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'rms x [m]'
                line_source.data = {'x': df['z'], 'y': df['rms x']}
            if line_button.active == 1:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'rms y [m]'
                line_source.data = {'x': df['z'], 'y': df['rms y']}
            if line_button.active == 2:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'rms x Emittance [m rad]'
                line_source.data = {'x': df['z'], 'y': df['rms x Emittance']}
            if line_button.active == 3:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'rms y  Emittance [m rad]'
                line_source.data = {'x': df['z'], 'y': df['rms y Emittance']}
            if line_button.active == 4:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'avg Energy [MeV]'
                line_source.data = {'x': df['z'], 'y': df['avg Energy']}
            if line_button.active == 5:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'x β-function [m]'
                line_source.data = {'x': df['z'], 'y': df['x β-function']}
            if line_button.active == 6:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'y β-function [m]'
                line_source.data = {'x': df['z'], 'y': df['y β-function']}
            if line_button.active == 7:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'Centroid x [m]'
                line_source.data = {'x': df['z'], 'y': df['centroid x']}
            if line_button.active == 8:
                line_plot.xaxis.axis_label = 'z [m]'
                line_plot.yaxis.axis_label = 'Centroid y [m]'
                line_source.data = {'x': df['z'], 'y': df['centroid y']}
        if tabs.active == 2:
            if field_button.active == 0:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'Ez [MV/m]'
                field_source.data = {'x': df['z'], 'y': df['avg Ez']}
            if field_button.active == 1:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'Ex [MV/m]'
                field_source.data = {'x': df['z'], 'y': df['avg Ex']}
            if field_button.active == 2:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'Ey [MV/m]'
                field_source.data = {'x': df['z'], 'y': df['avg Ey']}
            if field_button.active == 3:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'Bz [T]'
                field_source.data = {'x': df['z'], 'y': df['avg Bz']}
            if field_button.active == 4:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'Bx [T]'
                field_source.data = {'x': df['z'], 'y': df['avg Bx']}
            if field_button.active == 5:
                field_plot.xaxis.axis_label = 'z [m]'
                field_plot.yaxis.axis_label = 'By [T]'
                field_source.data = {'x': df['z'], 'y': df['avg By']}

    track_button.on_click(track_handler)
    refresh_button.on_click(refresh_handler)

    return tab
Пример #16
0
def accelerator_tab(acc):
    """Build the 'Accelerator' tab used to assemble and compile a lattice.

    The tab offers:

    - a length range slider and a Compile button;
    - a radio group choosing the element kind (Accels / Solenoids / Quadrupoles);
    - name, position and MaxField inputs;
    - a field-map file picker;
    - a line plot of the on-axis profile (Ez, Bz or Gz) of the selected kind.

    Parameters
    ----------
    acc : accelerator object
        Read to draw the initial Ez profile. The Compile button mutates it:
        ``z_start``, ``z_stop``, ``z`` and ``parameter`` are reassigned,
        an element is added and ``compile()`` is called.

    Returns
    -------
    Panel
        Panel titled 'Accelerator' wrapping the controls and the plot.
    """
    length = RangeSlider(start=0, end=25, value=(0, 1), step=.1,
                         title='Length [m]', format='0[.]0')
    pre = PreText(text='''Here you can compile an accelerator.''')
    compile_button = Button(label='Compile', button_type="success", height=45)
    element_button = RadioButtonGroup(
        labels=['Accels', 'Solenoids', 'Quadrupoles'], active=0)

    name_input = TextInput(value="Acc. ", title="Name:", width=100)
    z_input = TextInput(value="0", title="z [m]:", width=100)
    field_input = TextInput(value="0", title="MaxField [MV/m]:", width=100)
    file_input = FileInput(accept='.csv, .dat, .txt')
    parameters_input = column(file_input,
                              row(name_input, z_input, field_input))

    source = ColumnDataSource(data={'z': acc.z, 'Fz': acc.Ez(acc.z)})
    field_plot = figure(x_axis_label='z [m]', y_axis_label='Ez [MV/m]',
                        width=650, height=250)
    field_plot.line('z', 'Fz', source=source, line_color='#3A5785',
                    line_alpha=0.7, line_width=1)

    controls = row(column(pre, length, compile_button),
                   column(element_button, parameters_input))
    tab = Panel(child=column(controls, field_plot), title='Accelerator')

    # Per element kind, keyed by element_button.active:
    # (acc method that adds the element, acc field-profile attribute,
    #  MaxField input title, plot y-axis label, default element name).
    # Attribute names are resolved with getattr at call time, so the
    # profile is looked up only after acc.compile() has run.
    kinds = {
        0: ('add_accel', 'Ez', 'MaxField [MV/m]:', 'Ez [MV/m]', 'Acc. '),
        1: ('add_solenoid', 'Bz', 'MaxField [T]:', 'Bz [T]', 'Sol. '),
        2: ('add_quad', 'Gz', 'MaxField [T/m]:', 'Gz [T/m]', 'Quad. '),
    }

    def compile_handler(new, acc=acc):
        # Apply the slider range, add an element of the selected kind,
        # recompile and plot that kind's field profile; always print acc.
        acc.z_start, acc.z_stop = length.value
        acc.z = acc.parameter = np.arange(acc.z_start, acc.z_stop, acc.dz)
        element_name = name_input.value
        field_map = dirname(__file__) + '/data/' + file_input.filename
        z_position = float(z_input.value)
        peak_field = float(field_input.value)
        kind = kinds.get(element_button.active)
        if kind is not None:
            add_method, profile, title, label, _ = kind
            getattr(acc, add_method)(element_name, z_position, peak_field,
                                     field_map)
            acc.compile()
            source.data = {'z': acc.z, 'Fz': getattr(acc, profile)(acc.z)}
            field_input.title = title
            field_plot.yaxis.axis_label = label
        print(acc)

    def element_handler(new):
        # Switch the plotted profile, MaxField title, default name and
        # y-axis label to the newly selected kind (no element is added).
        kind = kinds.get(new)
        if kind is None:
            return
        _, profile, title, label, default_name = kind
        source.data = {'z': acc.z, 'Fz': getattr(acc, profile)(acc.z)}
        field_input.title = title
        name_input.value = default_name
        field_plot.yaxis.axis_label = label

    compile_button.on_click(compile_handler)
    element_button.on_click(element_handler)

    return tab
Пример #17
0
def widgets():
    """Show one instance of each common Bokeh widget, each with a caption.

    For every widget, ``put_text`` emits the caption first. The widget is
    then constructed and passed to ``bokeh.io.show``. The Bokeh names are
    imported inside this function.
    """
    from bokeh.io import show
    from bokeh.models import Select, CheckboxButtonGroup, Button, CheckboxGroup, ColorPicker, Dropdown, \
        FileInput, MultiSelect, RadioButtonGroup, RadioGroup, Slider, RangeSlider, TextAreaInput, TextInput, Toggle, \
        Paragraph, PreText, Div

    # Multi-line sample texts. Their continuation lines intentionally keep
    # four leading spaces, because that whitespace is part of the text.
    div_html = """Your <a href="https://en.wikipedia.org/wiki/HTML">HTML</a>-supported text is initialized with the <b>text</b> argument.  The
    remaining div arguments are <b>width</b> and <b>height</b>. For this example, those values
    are <i>200</i> and <i>100</i> respectively."""
    paragraph_text = """Your text is initialized with the 'text' argument.  The
    remaining Paragraph arguments are 'width' and 'height'. For this example, those values
    are 200 and 100 respectively."""
    pretext_text = """Your text is initialized with the 'text' argument.

    The remaining Paragraph arguments are 'width' and 'height'. For this example,
    those values are 500 and 100 respectively."""

    # (caption, zero-argument factory) pairs in display order. Each widget
    # is built only after its caption has been emitted.
    gallery = (
        ('Button', lambda: Button(label="Foo", button_type="success")),
        ('CheckboxButtonGroup',
         lambda: CheckboxButtonGroup(
             labels=["Option 1", "Option 2", "Option 3"], active=[0, 1])),
        ('CheckboxGroup',
         lambda: CheckboxGroup(labels=["Option 1", "Option 2", "Option 3"],
                               active=[0, 1])),
        ('ColorPicker',
         lambda: ColorPicker(color="#ff4466", title="Choose color:",
                             width=200)),
        ('Dropdown',
         lambda: Dropdown(label="Dropdown button", button_type="warning",
                          menu=[("Item 1", "item_1"), ("Item 2", "item_2"),
                                None, ("Item 3", "item_3")])),
        ('FileInput', FileInput),
        ('MultiSelect',
         lambda: MultiSelect(title="Option:", value=["foo", "quux"],
                             options=[("foo", "Foo"), ("bar", "BAR"),
                                      ("baz", "bAz"), ("quux", "quux")])),
        ('RadioButtonGroup',
         lambda: RadioButtonGroup(
             labels=["Option 1", "Option 2", "Option 3"], active=0)),
        ('RadioGroup',
         lambda: RadioGroup(labels=["Option 1", "Option 2", "Option 3"],
                            active=0)),
        ('Select',
         lambda: Select(title="Option:", value="foo",
                        options=["foo", "bar", "baz", "quux"])),
        ('Slider',
         lambda: Slider(start=0, end=10, value=1, step=.1, title="Stuff")),
        ('RangeSlider',
         lambda: RangeSlider(start=0, end=10, value=(1, 9), step=.1,
                             title="Stuff")),
        ('TextAreaInput',
         lambda: TextAreaInput(value="default", rows=6, title="Label:")),
        ('TextInput', lambda: TextInput(value="default", title="Label:")),
        ('Toggle', lambda: Toggle(label="Foo", button_type="success")),
        ('Div', lambda: Div(text=div_html, width=200, height=100)),
        ('Paragraph',
         lambda: Paragraph(text=paragraph_text, width=200, height=100)),
        ('PreText',
         lambda: PreText(text=pretext_text, width=500, height=100)),
    )

    for caption, build in gallery:
        put_text(caption)
        show(build())
Пример #18
0
# Axis selectors. Both list the keys of ``head_columns`` (defined elsewhere)
# in sorted order, and each sorted() call builds its own list.
sel_yaxis = Select(
    title='Select Y Axis',
    options=sorted(head_columns.keys()),
    value='Amplifier words',
)

sel_xaxis = Select(
    title='Select X Axis',
    options=sorted(head_columns.keys()),
    value='Outcome',
)

# Plot kinds offered by the plot-type dropdown.
options = ['scatter', 'vbar', 'varea', 'line']

# NOTE(review): Bokeh 2.0 appears to have dropped ``Dropdown.value``, in which
# case this keyword is rejected. Confirm the Bokeh version this targets.
sel_plot = Dropdown(
    label='Select plot type',
    button_type='success',
    menu=options,
    value='line',
)

# File picker (no callback is attached to it in this snippet).
add_row = FileInput()

# Essay id. Its initial value is interpolated into the essay title below.
inp_id = TextInput(placeholder='Insert ID', value='1')

# Essay viewer. The title is formatted once, from inp_id's value at
# construction time.
sh_ess = TextAreaInput(
    title=f'Essay {inp_id.value}: 338 words',
    rows=30,
    cols=50,
    value=
    """Dear local newspaper, I think effects computers have on people are great learning skills/affects because they give us time to chat with friends/new people, helps us learn about the globe(astronomy) and keeps us out of troble! Thing about! Dont you think so? How would you feel if your teenager is always on the phone with friends! Do you ever time to chat with your friends or buisness partner about things. Well now - there's a new way to chat the computer, theirs plenty of sites on the internet to do so: @ORGANIZATION1, @ORGANIZATION2, @CAPS1, facebook, myspace ect. Just think now while your setting up meeting with your boss on the computer, your teenager is having fun on the phone not rushing to get off cause you want to use it. How did you learn about other countrys/states outside of yours? Well I have by computer/internet, it's a new way to learn about what going on in our time! You might think your child spends a lot of time on the computer, but ask them so question about the economy, sea floor spreading or even about the @DATE1's you'll be surprise at how much he/she knows. Believe it or not the computer is much interesting then in class all day reading out of books. If your child is home on your computer or at a local library, it's better than being out with friends being fresh, or being perpressured to doing something they know isnt right. You might not know where your child is, @CAPS2 forbidde in a hospital bed because of a drive-by. Rather than your child on the computer learning, chatting or just playing games, safe and sound in your home or community place. Now I hope you have reached a point to understand and agree with me, because computers can have great effects on you or child because it gives us time to chat with friends/new people, helps us learn about the globe and believe or not keeps us out of troble. Thank you for listening.""",
)
# ------------------------------------------------------------------------------
Пример #19
0
def beam_tab(beam):
    """Build the 'Beam' tab for generating or uploading a particle beam.

    The tab combines:

    - species and distribution radio groups;
    - particle-count and current sliders;
    - phase-ellipse and offset text inputs;
    - a file picker, plus Generate and Upload buttons;
    - an x-y scatter plot of ``beam.df``.

    Parameters
    ----------
    beam : rp beam object
        Plotted initially. The button callbacks mutate it: ``beam.type``
        is set and ``beam.generate`` / ``beam.upload`` are called on it.

    Returns
    -------
    Panel
        Panel titled 'Beam' wrapping the controls and the scatter plot.
    """
    pre = PreText(text='''Here you can generate or upload a beam.''')

    species_button = RadioButtonGroup(
        labels=['electron', 'positron', 'proton', 'antiproton'], active=0)

    select_quantity = Slider(start=1_000, end=100_000,
                             step=1_000, value=1_000,
                             title='Number of particles')

    select_current = Slider(start=0, end=5_000,
                            step=100, value=0,
                            title='Current [A]')
    # Phase ellipse
    x_input = TextInput(value="0", title="X [m]:", width=75)
    y_input = TextInput(value="0", title="Y [m]:", width=75)
    z_input = TextInput(value="0", title="Z [m]:", width=75)

    px_input = TextInput(value="0", title="Px [MeV/c]:", width=75)
    py_input = TextInput(value="0", title="Py [MeV/c]:", width=75)
    pz_input = TextInput(value="0", title="Pz [MeV/c]:", width=75)

    # Offsets forwarded to beam.generate().
    x_off_input = TextInput(value="0", title="X_off [m]:", width=75)
    y_off_input = TextInput(value="0", title="Y_off [m]:", width=75)
    px_off_input = TextInput(value="0", title="Px_off:", width=60)
    py_off_input = TextInput(value="0", title="Py_off:", width=60)

    # generate_handler reads all four offsets, so all four must be laid out.
    # Otherwise Px_off / Py_off could never differ from "0".
    phase_ell = column(
        row(x_input, y_input, z_input, x_off_input, px_off_input),
        row(px_input, py_input, pz_input, y_off_input, py_off_input))

    select_distribution = RadioButtonGroup(labels=['Uniform', 'Gauss'],
                                           active=0)

    file_input = FileInput(accept='.csv, .ini')
    generate_button = Button(label='Generate', button_type="success",
                             width=155)
    upload_button = Button(label='Upload', button_type="warning",
                           width=155)

    source = ColumnDataSource(data={'x': beam.df['x'], 'y': beam.df['y']})
    p = figure(x_axis_label='x [m]', y_axis_label='y [m]',
               width=410, height=400)
    p.scatter('x', 'y', source=source, size=0.5, color='#3A5785', alpha=0.5)

    controls = column(pre, species_button, select_quantity,
                      select_distribution, select_current,
                      phase_ell, file_input,
                      row(upload_button, generate_button))
    # Create a row layout
    tab = Panel(child=row(controls, p), title='Beam')

    # Directory that holds input files and receives generated output.
    data_dir = dirname(__file__) + '/data/'

    def pick(choices, active, what):
        """Return ``choices[active]``; raise ValueError for an invalid index."""
        if active is None or not 0 <= active < len(choices):
            raise ValueError(f'Unsupported {what} selection: {active!r}')
        return choices[active]

    def selected_species():
        """Return the rp species matching ``species_button``'s labels."""
        return pick((rp.electron, rp.positron, rp.proton, rp.antiproton),
                    species_button.active, 'species')

    def generate_handler(new, beam=beam):
        """Generate a beam from the widget values, then plot x against y."""
        current = select_current.value
        n_particles = select_quantity.value
        X = float(x_input.value)
        Y = float(y_input.value)
        Z = float(z_input.value)
        Px = float(px_input.value)
        Py = float(py_input.value)
        Pz = float(pz_input.value)
        X_off = float(x_off_input.value)
        Y_off = float(y_off_input.value)
        Px_off = float(px_off_input.value)
        Py_off = float(py_off_input.value)
        # 'Uniform' maps to the 'KV' distribution and 'Gauss' to 'GA'.
        distribution = rp.Distribution(
            name=pick(('KV', 'GA'), select_distribution.active,
                      'distribution'),
            x=X, y=Y, z=Z, px=Px, py=Py, pz=Pz)
        species = selected_species()
        beam.type = species
        # Charge = I * Z / c, signed like the species charge.
        Q = np.sign(species.charge) * current * Z / rp.c
        beam.generate(distribution, n=n_particles, charge=Q, path=data_dir,
                      x_off=X_off, y_off=Y_off, px_off=Px_off, py_off=Py_off)
        source.data = {'x': beam.df['x'], 'y': beam.df['y']}
        print(beam)

    def upload_handler(new, beam=beam):
        """Load a beam from the picked file, plot it and sync the inputs."""
        current = select_current.value
        species = selected_species()
        beam.type = species
        # NOTE(review): the z extent comes from beam.df *before* beam.upload
        # replaces it, i.e. from the previous beam rather than the file being
        # loaded. Confirm whether the charge should use the uploaded beam.
        z_extent = beam.df['z'].max() - beam.df['z'].min()
        Q = np.sign(species.charge) * current * z_extent / rp.c
        beam.upload(data_dir + file_input.filename, charge=Q, path=data_dir)
        source.data = {'x': beam.df['x'], 'y': beam.df['y']}
        # Copy the loaded beam's maxima (rounded to 3 decimals) into the
        # inputs. Z is shown as twice its rounded maximum (presumably the
        # full extent of a bunch centred on 0 -- confirm).
        x_input.value = str(np.around(beam.df['x'].max(), 3))
        y_input.value = str(np.around(beam.df['y'].max(), 3))
        z_input.value = str(np.around(beam.df['z'].max(), 3) * 2)
        px_input.value = str(np.around(beam.df['px'].max(), 3))
        py_input.value = str(np.around(beam.df['py'].max(), 3))
        pz_input.value = str(np.around(beam.df['pz'].max(), 3))
        select_quantity.value = beam.n
        print(beam)

    generate_button.on_click(generate_handler)
    upload_button.on_click(upload_handler)

    return tab
Пример #20
0
def create():
    doc = curdoc()
    config = pyzebra.AnatricConfig()

    def _load_config_file(file):
        config.load_from_file(file)

        logfile_textinput.value = config.logfile
        logfile_verbosity.value = config.logfile_verbosity

        filelist_type.value = config.filelist_type
        filelist_format_textinput.value = config.filelist_format
        filelist_datapath_textinput.value = config.filelist_datapath
        filelist_ranges_textareainput.value = "\n".join(
            map(str, config.filelist_ranges))

        crystal_sample_textinput.value = config.crystal_sample
        lambda_textinput.value = config.crystal_lambda
        zeroOM_textinput.value = config.crystal_zeroOM
        zeroSTT_textinput.value = config.crystal_zeroSTT
        zeroCHI_textinput.value = config.crystal_zeroCHI
        ub_textareainput.value = config.crystal_UB

        dataFactory_implementation_select.value = config.dataFactory_implementation
        if config.dataFactory_dist1 is not None:
            dataFactory_dist1_textinput.value = config.dataFactory_dist1
        if config.dataFactory_dist2 is not None:
            dataFactory_dist2_textinput.value = config.dataFactory_dist2
        if config.dataFactory_dist3 is not None:
            dataFactory_dist3_textinput.value = config.dataFactory_dist3
        reflectionPrinter_format_select.value = config.reflectionPrinter_format

        if config.algorithm == "adaptivemaxcog":
            algorithm_params.active = 0
            threshold_textinput.value = config.threshold
            shell_textinput.value = config.shell
            steepness_textinput.value = config.steepness
            duplicateDistance_textinput.value = config.duplicateDistance
            maxequal_textinput.value = config.maxequal
            aps_window_textinput.value = str(
                tuple(map(int, config.aps_window.values())))

        elif config.algorithm == "adaptivedynamic":
            algorithm_params.active = 1
            adm_window_textinput.value = str(
                tuple(map(int, config.adm_window.values())))
            border_textinput.value = str(
                tuple(map(int, config.border.values())))
            minWindow_textinput.value = str(
                tuple(map(int, config.minWindow.values())))
            reflectionFile_textinput.value = config.reflectionFile
            targetMonitor_textinput.value = config.targetMonitor
            smoothSize_textinput.value = config.smoothSize
            loop_textinput.value = config.loop
            minPeakCount_textinput.value = config.minPeakCount
            displacementCurve_textinput.value = "\n".join(
                map(str, config.displacementCurve))

        else:
            raise ValueError("Unknown processing mode.")

    def upload_button_callback(_attr, _old, new):
        with io.BytesIO(base64.b64decode(new)) as file:
            _load_config_file(file)

    upload_div = Div(text="Open .xml config:")
    upload_button = FileInput(accept=".xml", width=200)
    upload_button.on_change("value", upload_button_callback)

    # General parameters
    # ---- logfile
    def logfile_textinput_callback(_attr, _old, new):
        config.logfile = new

    logfile_textinput = TextInput(title="Logfile:", value="logfile.log")
    logfile_textinput.on_change("value", logfile_textinput_callback)

    def logfile_verbosity_callback(_attr, _old, new):
        config.logfile_verbosity = new

    logfile_verbosity = TextInput(title="verbosity:", width=70)
    logfile_verbosity.on_change("value", logfile_verbosity_callback)

    # ---- FileList
    def filelist_type_callback(_attr, _old, new):
        config.filelist_type = new

    filelist_type = Select(title="File List:",
                           options=["TRICS", "SINQ"],
                           width=100)
    filelist_type.on_change("value", filelist_type_callback)

    def filelist_format_textinput_callback(_attr, _old, new):
        config.filelist_format = new

    filelist_format_textinput = TextInput(title="format:", width=290)
    filelist_format_textinput.on_change("value",
                                        filelist_format_textinput_callback)

    def filelist_datapath_textinput_callback(_attr, _old, new):
        config.filelist_datapath = new

    filelist_datapath_textinput = TextInput(title="datapath:")
    filelist_datapath_textinput.on_change(
        "value", filelist_datapath_textinput_callback)

    def filelist_ranges_textareainput_callback(_attr, _old, new):
        ranges = []
        for line in new.splitlines():
            ranges.append(re.findall(r"\b\d+\b", line))
        config.filelist_ranges = ranges

    filelist_ranges_textareainput = TextAreaInput(title="ranges:", rows=1)
    filelist_ranges_textareainput.on_change(
        "value", filelist_ranges_textareainput_callback)

    # ---- crystal
    def crystal_sample_textinput_callback(_attr, _old, new):
        config.crystal_sample = new

    crystal_sample_textinput = TextInput(title="Sample Name:", width=290)
    crystal_sample_textinput.on_change("value",
                                       crystal_sample_textinput_callback)

    def lambda_textinput_callback(_attr, _old, new):
        config.crystal_lambda = new

    lambda_textinput = TextInput(title="lambda:", width=100)
    lambda_textinput.on_change("value", lambda_textinput_callback)

    def ub_textareainput_callback(_attr, _old, new):
        config.crystal_UB = new

    ub_textareainput = TextAreaInput(title="UB matrix:", height=100)
    ub_textareainput.on_change("value", ub_textareainput_callback)

    def zeroOM_textinput_callback(_attr, _old, new):
        config.crystal_zeroOM = new

    zeroOM_textinput = TextInput(title="zeroOM:", width=100)
    zeroOM_textinput.on_change("value", zeroOM_textinput_callback)

    def zeroSTT_textinput_callback(_attr, _old, new):
        config.crystal_zeroSTT = new

    zeroSTT_textinput = TextInput(title="zeroSTT:", width=100)
    zeroSTT_textinput.on_change("value", zeroSTT_textinput_callback)

    def zeroCHI_textinput_callback(_attr, _old, new):
        config.crystal_zeroCHI = new

    zeroCHI_textinput = TextInput(title="zeroCHI:", width=100)
    zeroCHI_textinput.on_change("value", zeroCHI_textinput_callback)

    # ---- DataFactory
    def dataFactory_implementation_select_callback(_attr, _old, new):
        config.dataFactory_implementation = new

    dataFactory_implementation_select = Select(
        title="DataFactory implement.:",
        options=DATA_FACTORY_IMPLEMENTATION,
        width=145,
    )
    dataFactory_implementation_select.on_change(
        "value", dataFactory_implementation_select_callback)

    def dataFactory_dist1_textinput_callback(_attr, _old, new):
        config.dataFactory_dist1 = new

    dataFactory_dist1_textinput = TextInput(title="dist1:", width=75)
    dataFactory_dist1_textinput.on_change(
        "value", dataFactory_dist1_textinput_callback)

    def dataFactory_dist2_textinput_callback(_attr, _old, new):
        config.dataFactory_dist2 = new

    dataFactory_dist2_textinput = TextInput(title="dist2:", width=75)
    dataFactory_dist2_textinput.on_change(
        "value", dataFactory_dist2_textinput_callback)

    def dataFactory_dist3_textinput_callback(_attr, _old, new):
        config.dataFactory_dist3 = new

    dataFactory_dist3_textinput = TextInput(title="dist3:", width=75)
    dataFactory_dist3_textinput.on_change(
        "value", dataFactory_dist3_textinput_callback)

    # ---- BackgroundProcessor

    # ---- DetectorEfficency

    # ---- ReflectionPrinter
    def reflectionPrinter_format_select_callback(_attr, _old, new):
        config.reflectionPrinter_format = new

    reflectionPrinter_format_select = Select(
        title="ReflectionPrinter format:",
        options=REFLECTION_PRINTER_FORMATS,
        width=145,
    )
    reflectionPrinter_format_select.on_change(
        "value", reflectionPrinter_format_select_callback)

    # Adaptive Peak Detection (adaptivemaxcog)
    # ---- threshold
    def threshold_textinput_callback(_attr, _old, new):
        config.threshold = new

    threshold_textinput = TextInput(title="Threshold:", width=145)
    threshold_textinput.on_change("value", threshold_textinput_callback)

    # ---- shell
    def shell_textinput_callback(_attr, _old, new):
        config.shell = new

    shell_textinput = TextInput(title="Shell:", width=145)
    shell_textinput.on_change("value", shell_textinput_callback)

    # ---- steepness
    def steepness_textinput_callback(_attr, _old, new):
        config.steepness = new

    steepness_textinput = TextInput(title="Steepness:", width=145)
    steepness_textinput.on_change("value", steepness_textinput_callback)

    # ---- duplicateDistance
    def duplicateDistance_textinput_callback(_attr, _old, new):
        config.duplicateDistance = new

    duplicateDistance_textinput = TextInput(title="Duplicate Distance:",
                                            width=145)
    duplicateDistance_textinput.on_change(
        "value", duplicateDistance_textinput_callback)

    # ---- maxequal
    def maxequal_textinput_callback(_attr, _old, new):
        config.maxequal = new

    maxequal_textinput = TextInput(title="Max Equal:", width=145)
    maxequal_textinput.on_change("value", maxequal_textinput_callback)

    # ---- window
    def aps_window_textinput_callback(_attr, _old, new):
        config.aps_window = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    aps_window_textinput = TextInput(title="Window (x, y, z):", width=145)
    aps_window_textinput.on_change("value", aps_window_textinput_callback)

    # Adaptive Dynamic Mask Integration (adaptivedynamic)
    # ---- window
    def adm_window_textinput_callback(_attr, _old, new):
        config.adm_window = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    adm_window_textinput = TextInput(title="Window (x, y, z):", width=145)
    adm_window_textinput.on_change("value", adm_window_textinput_callback)

    # ---- border
    def border_textinput_callback(_attr, _old, new):
        config.border = dict(zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    border_textinput = TextInput(title="Border (x, y, z):", width=145)
    border_textinput.on_change("value", border_textinput_callback)

    # ---- minWindow
    def minWindow_textinput_callback(_attr, _old, new):
        config.minWindow = dict(
            zip(("x", "y", "z"), re.findall(r"\b\d+\b", new)))

    minWindow_textinput = TextInput(title="Min Window (x, y, z):", width=145)
    minWindow_textinput.on_change("value", minWindow_textinput_callback)

    # ---- reflectionFile
    def reflectionFile_textinput_callback(_attr, _old, new):
        config.reflectionFile = new

    reflectionFile_textinput = TextInput(title="Reflection File:", width=145)
    reflectionFile_textinput.on_change("value",
                                       reflectionFile_textinput_callback)

    # ---- targetMonitor
    def targetMonitor_textinput_callback(_attr, _old, new):
        config.targetMonitor = new

    targetMonitor_textinput = TextInput(title="Target Monitor:", width=145)
    targetMonitor_textinput.on_change("value",
                                      targetMonitor_textinput_callback)

    # ---- smoothSize
    def smoothSize_textinput_callback(_attr, _old, new):
        config.smoothSize = new

    smoothSize_textinput = TextInput(title="Smooth Size:", width=145)
    smoothSize_textinput.on_change("value", smoothSize_textinput_callback)

    # ---- loop
    def loop_textinput_callback(_attr, _old, new):
        config.loop = new

    loop_textinput = TextInput(title="Loop:", width=145)
    loop_textinput.on_change("value", loop_textinput_callback)

    # ---- minPeakCount
    def minPeakCount_textinput_callback(_attr, _old, new):
        config.minPeakCount = new

    minPeakCount_textinput = TextInput(title="Min Peak Count:", width=145)
    minPeakCount_textinput.on_change("value", minPeakCount_textinput_callback)

    # ---- displacementCurve
    def displacementCurve_textinput_callback(_attr, _old, new):
        maps = []
        for line in new.splitlines():
            maps.append(re.findall(r"\d+(?:\.\d+)?", line))
        config.displacementCurve = maps

    displacementCurve_textinput = TextAreaInput(
        title="Displ. Curve (2θ, x, y):", width=145, height=100)
    displacementCurve_textinput.on_change(
        "value", displacementCurve_textinput_callback)

    def algorithm_tabs_callback(_attr, _old, new):
        if new == 0:
            config.algorithm = "adaptivemaxcog"
        else:
            config.algorithm = "adaptivedynamic"

    algorithm_params = Tabs(tabs=[
        Panel(
            child=column(
                row(threshold_textinput, shell_textinput, steepness_textinput),
                row(duplicateDistance_textinput, maxequal_textinput,
                    aps_window_textinput),
            ),
            title="Peak Search",
        ),
        Panel(
            child=column(
                row(adm_window_textinput, border_textinput,
                    minWindow_textinput),
                row(reflectionFile_textinput, targetMonitor_textinput,
                    smoothSize_textinput),
                row(loop_textinput, minPeakCount_textinput,
                    displacementCurve_textinput),
            ),
            title="Dynamic Integration",
        ),
    ])
    algorithm_params.on_change("active", algorithm_tabs_callback)

    def process_button_callback():
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/config.xml"
            config.save_as(temp_file)
            pyzebra.anatric(temp_file,
                            anatric_path=doc.anatric_path,
                            cwd=temp_dir)

            with open(os.path.join(temp_dir, config.logfile)) as f_log:
                output_log.value = f_log.read()

            with open(os.path.join(temp_dir,
                                   config.reflectionPrinter_file)) as f_res:
                output_res.value = f_res.read()

    process_button = Button(label="Process", button_type="primary")
    process_button.on_click(process_button_callback)

    output_log = TextAreaInput(title="Logfile output:",
                               height=320,
                               width=465,
                               disabled=True)
    output_res = TextAreaInput(title="Result output:",
                               height=320,
                               width=465,
                               disabled=True)
    output_config = TextAreaInput(title="Current config:",
                                  height=320,
                                  width=465,
                                  disabled=True)

    general_params_layout = column(
        row(column(Spacer(height=2), upload_div), upload_button),
        row(logfile_textinput, logfile_verbosity),
        row(filelist_type, filelist_format_textinput),
        filelist_datapath_textinput,
        filelist_ranges_textareainput,
        row(crystal_sample_textinput, lambda_textinput),
        ub_textareainput,
        row(zeroOM_textinput, zeroSTT_textinput, zeroCHI_textinput),
        row(
            dataFactory_implementation_select,
            dataFactory_dist1_textinput,
            dataFactory_dist2_textinput,
            dataFactory_dist3_textinput,
        ),
        row(reflectionPrinter_format_select),
    )

    tab_layout = row(
        general_params_layout,
        column(output_config, algorithm_params, row(process_button)),
        column(output_log, output_res),
    )

    async def update_config():
        output_config.value = config.tostring()

    doc.add_periodic_callback(update_config, 1000)

    return Panel(child=tab_layout, title="hdf anatric")
        p.xaxis.axis_label = 'Time'
        p.yaxis.visible = False
        p.add_tools(hover)
    curdoc().clear()
    div.text = '<h1>' + file_input.filename + '</h1>'

    curdoc().add_root(column(file_input, div, radio_button,
                             get_current_plot()))


def get_current_plot():
    """Return the figure that corresponds to the active radio button.

    ``radio_button.active`` equal to 0 selects ``p1`` and 1 selects ``p2``;
    every other value (e.g. ``None`` when nothing is selected) falls back
    to ``p3``.
    """
    selected_index = radio_button.active
    if selected_index == 0:
        return p1
    if selected_index == 1:
        return p2
    return p3


def change_current_plot(attr, old, new):
    """Rebuild the document so it shows the plot picked by the radio group.

    Uses the Bokeh ``on_change`` callback signature; ``attr``, ``old`` and
    ``new`` are ignored because the plot is chosen by ``get_current_plot()``.
    """
    document = curdoc()
    document.clear()
    widgets = (file_input, div, radio_button, get_current_plot())
    document.add_root(column(*widgets))


# Wire up the widgets: a new JSON upload is handed to load_trace_data, and
# toggling the radio group swaps the displayed plot.
file_input = FileInput(accept=".json")
file_input.on_change("value", load_trace_data)
radio_button.on_change("active", change_current_plot)

# Initially only the file picker is added to the document.
curdoc().add_root(file_input)
Пример #22
0
    def plot_trends_scatter_bokeh(self):
        """Build the interactive trend map and add it to the current document.

        The page (document title "Quality monitoring") contains:

        * a ``FileInput`` for a ``.plot.csv`` file whose ``#key:value`` header
          lines name the HDF5 cache file and, optionally, a breaks pickle;
        * a ``Select`` listing the ``*_sn`` variables found in that csv;
        * a slope threshold ``RangeSlider`` and a ``DateRangeSlider`` (only
          made visible when a breaks file is given);
        * a map ``p`` with a land mask and one dot per site colored by slope;
        * a time-series panel ``p2`` for the site selected on the map, with
          sensor-period bands and, when available, segment/break markers.

        Side effects: the uploaded csv is written to
        ``self.app_dir / 'data' / 'tmp_input.csv'`` and the callbacks store
        their state on ``self`` (``hf``, ``b_breaks``, ``df_breaks``,
        ``dates``, ``point_names``, ``ts``, ``dfs``, ``sn_max``).
        """
        lon = self.lon
        lat = self.lat
        dw = lon[-1] - lon[0]
        dh = lat[0] - lat[-1]

        self.dfs = pd.DataFrame.from_dict(
            data=dict(LONGITUDE=[], LATITUDE=[], NAME=[], slope=[], id2=[]))

        p = figure(plot_width=int(400. * dw / dh),
                   plot_height=400,
                   match_aspect=True,
                   tools="pan,wheel_zoom,box_zoom,tap,reset",
                   output_backend="webgl")

        ##--- Create a modified version of seismic colormap

        from bokeh.models import LinearColorMapper, ColorBar
        import matplotlib.cm as mcm
        import matplotlib.colors as mcol

        fcmap = mcm.get_cmap('seismic')
        cmap_mod = [fcmap(i) for i in np.linspace(0, 1, 15)]
        # Replace the white in the middle by the yellow of RdYlGn
        cmap_mod[7] = mcm.get_cmap('RdYlGn')(0.5)
        scmap = mcol.LinearSegmentedColormap.from_list("", cmap_mod)
        ## Extract 256 colors from the new colormap and convert them to hex
        cmap_mod = [
            "#%02x%02x%02x" % (int(255 * r), int(255 * g), int(255 * b))
            for r, g, b, _ in (scmap(i) for i in np.linspace(0, 1, 256))
        ]
        ## Placeholder color mapper until data is loaded. It is built from the
        ## 256 explicit colors above because the mapper does not interpolate
        ## linearly between palette entries.
        self.sn_max = 0.001
        color_mapper = LinearColorMapper(palette=cmap_mod,
                                         low=-self.sn_max,
                                         high=self.sn_max)

        def tmp_csv_path():
            # Where the uploaded csv is cached between callbacks.
            return self.app_dir / 'data' / 'tmp_input.csv'

        ##--- Select CSV file to read

        # Break glyphs on p2 are created by the first upload that provides a
        # breaks file and reused afterwards, so repeated uploads do not stack
        # duplicate renderers.
        breaks_renderers = []

        def upload_input_csv(attr, old, new):
            ## Read, decode and save input data to tmp file
            print("Data upload succeeded")
            print("file_input.filename=", file_input.filename)
            raw = base64.b64decode(file_input.value)
            data = raw.decode('utf8')
            # Write the raw bytes so the cached copy stays UTF-8 whatever the
            # platform's default text encoding is (pandas re-reads it).
            with open(tmp_csv_path(), 'wb') as f:
                f.write(raw)

            ## Get csv meta data from the '#key:value' header lines.
            # splitlines() copes with '\r\n' endings and partition() keeps
            # values that themselves contain ':'.
            meta = {}
            for line in data.splitlines():
                if line.startswith('#'):
                    key, _, value = line.partition(':')
                    meta[key.strip()] = value.strip()

            self.hf = h5py.File(
                self.app_dir / 'data' / meta['#input_extract_cache_file'], 'r')
            if '#input_breaks_pickle_file' in meta:
                self.b_breaks = True
                # NOTE: read_pickle can execute arbitrary code; the pickle
                # named by the uploaded csv must come from a trusted source.
                self.df_breaks = pd.read_pickle(
                    self.app_dir / 'data' / meta['#input_breaks_pickle_file'])
                if not breaks_renderers:
                    ## Timeseries segment, drawn in red over the full series
                    breaks_renderers.append(
                        p2.line(x='dates',
                                y='var',
                                source=segment_source,
                                line_color='red'))
                    ## Vertical dashed lines for the breaks
                    breaks_renderers.append(
                        p2.segment(x0="x",
                                   y0="y0",
                                   x1="x",
                                   y1="y1",
                                   line_color="black",
                                   line_dash='dashed',
                                   line_width=2,
                                   source=breaks_source))
            else:
                self.b_breaks = False
            # Get date range from h5 file
            self.dates = self.hf['meta/ts_dates'][:].view(
                'datetime64[s]').tolist()
            self.point_names = [
                i.decode('utf8') for i in self.hf['meta/point_names'][:]
            ]
            # Some dates are set to 1970 (i.e. stored as 0? to be checked)
            d_init = [d for d in self.dates if d.year != 1970]
            ts_source.data = dict(dates=d_init, var=np.zeros_like(d_init))

            ## Read tmp file and update select widget with available variables
            df = pd.read_csv(tmp_csv_path(), sep=';', comment='#')
            # Strip only the trailing '_sn' suffix: str.replace() would also
            # drop '_sn' found inside a name (e.g. 'x_snow_sn' -> 'xow').
            in_var = [
                c[:-len('_sn')] for c in df.columns if c.endswith('_sn')
            ]
            print(in_var)
            if not in_var:
                print("No '_sn' variable found in", file_input.filename)
                return
            select.disabled = False
            select.options = in_var
            select.value = in_var[0]

            ## If there is only one variable in the csv, plot it directly
            if len(in_var) == 1:
                read_data_for_plotting(in_var[0])

        file_input = FileInput(
            accept=".plot.csv")  # comma separated list if any
        file_input.on_change('value', upload_input_csv)

        ## Add variable selection
        def select_variable(attr, old, new):
            read_data_for_plotting(new)

        select = Select(title="Variable in csv:", disabled=True)
        select.on_change('value', select_variable)

        ##--- Add land mask

        # The image glyph takes a list of 2-D arrays. self.rebin is asked for
        # a shape 5x smaller on each axis (presumably a downsampling).
        mask = self.rebin(
            self.mask,
            (int(self.mask.shape[0] / 5), int(self.mask.shape[1] / 5)))
        p.image(image=[np.flipud(mask)],
                x=lon[0],
                y=lat[-1],
                dw=dw,
                dh=dh,
                palette=('#FFFFFF', '#EEEEEE', '#DDDDDD', '#CCCCCC', '#BBBBBB',
                         '#AAAAAA', '#999999', '#888888'),
                level="image")
        p.grid.grid_line_width = 0.5

        ##--- Read selected data, filter and convert to ColumnDataSource

        def read_data_for_plotting(var):
            ## Get the variable from the input h5 cache file
            self.ts = self.hf['vars/' + var][:, 0, :].T

            ## Get data from input csv
            var = var + '_sn'
            read_kwargs = dict(sep=';', comment='#')
            if self.b_breaks:
                read_kwargs['parse_dates'] = ['start_date', 'end_date']
            df = pd.read_csv(tmp_csv_path(), **read_kwargs)
            id_sites_in_cache_file = {
                s: i for i, s in enumerate(self.point_names)
            }
            df['id2'] = df['NAME'].map(id_sites_in_cache_file)
            df = df.dropna(subset=[var])

            columns = ['LONGITUDE', 'LATITUDE', 'NAME', var, 'id2']
            if self.b_breaks:
                columns += ['lvl', 'start_date', 'end_date']
            # Explicit copy: the selection may be modified below.
            self.dfs = df.loc[:, columns].copy()
            if not self.b_breaks:
                # Without a breaks file, fill the break columns with zeros so
                # the ColumnDataSource always has the same set of columns.
                for col in ('lvl', 'start_date', 'end_date'):
                    self.dfs[col] = np.zeros_like(self.dfs[var])

            self.dfs = self.dfs.rename(columns={var: 'slope'})

            source.data = ColumnDataSource.from_df(self.dfs)

            # Largest slope *magnitude*: abs(nanmax(slope)) would ignore
            # strongly negative trends and clip them out of the color range.
            self.sn_max = np.nanmax(np.abs(self.dfs['slope']))
            color_mapper.low = -self.sn_max
            color_mapper.high = self.sn_max

            slider.end = self.sn_max * 1000.
            slider.step = self.sn_max * 1000. / 20.
            slider.value = (0.0, self.sn_max * 1000.)

            if self.b_breaks:
                slider_date.start = self.dates[0]
                slider_date.end = self.dates[-1]
                slider_date.value = (self.dates[0], self.dates[-1])
                slider_date.visible = True

        ##--- Add scatter

        ## Create source that will be populated according to slider
        source = ColumnDataSource(data=dict(LONGITUDE=[],
                                            LATITUDE=[],
                                            NAME=[],
                                            slope=[],
                                            id2=[],
                                            lvl=[],
                                            start_date=[],
                                            end_date=[]))

        scatter_renderer = p.scatter(x='LONGITUDE',
                                     y='LATITUDE',
                                     size=12,
                                     color={
                                         'field': 'slope',
                                         'transform': color_mapper
                                     },
                                     source=source)

        color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12)
        p.add_layout(color_bar, 'right')

        ## Add hover tool that only act on scatter and not on the background land mask
        p.add_tools(HoverTool(renderers=[scatter_renderer], mode='mouse'))

        ##--- Add slider

        slider = RangeSlider(start=0.0,
                             end=self.sn_max * 1000.,
                             value=(0.0, self.sn_max * 1000.),
                             step=self.sn_max * 1000. / 20.,
                             title="Trend threshold [10e-3]")
        slider_date = DateRangeSlider(title="Date range: ",
                                      start=dt.date(1981, 1, 1),
                                      end=dt.date.today(),
                                      value=(dt.date(1981, 1,
                                                     1), dt.date.today()),
                                      step=1,
                                      visible=False)

        def to_timestamp(value):
            # DateRangeSlider values come from the browser as epoch
            # milliseconds, but are date/datetime objects when set from Python
            # (initial value, read_data_for_plotting); unit='ms' is only
            # meaningful for the numeric form.
            if isinstance(value, (int, float, np.number)):
                return pd.to_datetime(value, unit='ms')
            return pd.to_datetime(value)

        ## Slider Python callback, shared by both sliders: it reads their
        ## current values rather than relying on `new`.
        def update_scatter(attr, old, new):
            slope_sel = slider.value
            abs_slope = np.abs(self.dfs['slope'])
            keep = ((abs_slope >= 0.001 * slope_sel[0]) &
                    (abs_slope <= 0.001 * slope_sel[1]))
            if self.b_breaks:
                date_sel = [to_timestamp(d) for d in slider_date.value]
                keep &= ((self.dfs['start_date'] >= date_sel[0]) &
                         (self.dfs['end_date'] <= date_sel[1]))
            source.data = ColumnDataSource.from_df(self.dfs.loc[keep])

        slider.on_change('value', update_scatter)
        slider_date.on_change('value', update_scatter)

        ##--- Add time series of selected point

        pw = int(400. * dw / dh)
        ph = 200
        p2 = figure(plot_width=pw,
                    plot_height=ph,
                    tools="pan,wheel_zoom,box_zoom,reset",
                    output_backend="webgl",
                    x_axis_type="datetime",
                    title='---')

        p2.add_tools(
            HoverTool(
                tooltips=[
                    ("Date",
                     "@dates{%Y-%m-%d}"),  # must specify desired format here
                    ("Value", "@var")
                ],
                formatters={"@dates": "datetime"},
                mode='vline'))

        ## Create source and plot it
        d_init = [dt.datetime(1981, 9, 20), dt.datetime(2020, 6, 30)]
        ts_source = ColumnDataSource(
            data=dict(dates=d_init, var=np.zeros_like(d_init)))
        segment_source = ColumnDataSource(
            data=dict(dates=d_init, var=np.zeros_like(d_init)))
        breaks_source = ColumnDataSource(data=dict(x=[], y0=[], y1=[]))
        # Full timeseries line
        p2.line(x='dates', y='var', source=ts_source)

        ## Add satellite periods (dates are day-month-year)
        sensor_periods = [
            ('NOAA7', ('20-09-1981', '31-12-1984')),
            ('NOAA9', ('20-03-1985', '10-11-1988')),
            ('NOAA11', ('30-11-1988', '20-09-1994')),
            ('NOAA14', ('10-02-1995', '10-03-2001')),
            ('NOAA16', ('20-03-2001', '10-09-2002')),
            ('NOAA17', ('20-09-2002', '31-12-2005')),
            ('VGT1', ('10-04-1998', '31-01-2003')),
            ('VGT2', ('31-01-2003', '31-05-2014')),
            ('PROBAV', ('31-10-2013', '30-06-2020')),
        ]
        sensor_dates = [
            (name, [dt.datetime.strptime(d, "%d-%m-%Y") for d in period])
            for name, period in sensor_periods
        ]

        import itertools
        from bokeh.palettes import Category10 as palette
        colors = itertools.cycle(palette[10])
        # Non-VGT sensors are shaded in the upper half of p2, VGT sensors in
        # the lower half (positions in screen units).
        top_ba = []
        bottom_ba = []
        for (name, (start, end)), color in zip(sensor_dates, colors):
            if 'VGT' not in name:
                top_ba.append(
                    BoxAnnotation(top=ph,
                                  top_units='screen',
                                  bottom=int(ph / 2),
                                  bottom_units='screen',
                                  left=start,
                                  right=end,
                                  fill_alpha=0.2,
                                  fill_color=color))
            else:
                bottom_ba.append(
                    BoxAnnotation(top=int(ph / 2),
                                  top_units='screen',
                                  bottom=0,
                                  bottom_units='screen',
                                  left=start,
                                  right=end,
                                  fill_alpha=0.2,
                                  fill_color=color))
        for ba in top_ba + bottom_ba:
            p2.add_layout(ba)

        def update_ts(attr, old, new):
            """Show the time series of the last selected site in p2.

            attr: 'indices'
            old (list): the previous selected indices
            new (list): the new selected indices
            """
            if len(new) == 0:
                return
            ## Use the last index because this is the last drawn point that is visible
            last = new[-1]
            site_name = source.data['NAME'][last]
            site_id = int(source.data['id2'][last])
            ts_source.data = dict(dates=self.dates, var=self.ts[site_id])
            if self.b_breaks:
                ## Add segment
                multi_idx = (site_name, str(source.data['lvl'][last]))
                segment_slice = self.df_breaks.loc[multi_idx]['x'].astype(
                    'int')
                segment_source.data = dict(
                    dates=[self.dates[i] for i in segment_slice],
                    var=self.ts[site_id][segment_slice])
                ## Add breaks, framed by the first and last dates
                xb = [
                    pd.to_datetime(i)
                    for i in self.df_breaks.loc[site_name]['bp_date'].values
                    if not pd.isnull(i)
                ]
                xb = ([pd.to_datetime(self.dates[0])] + xb +
                      [pd.to_datetime(self.dates[-1])])
                y0b = np.nanmin(self.ts[site_id]) * np.ones(len(xb))
                y1b = np.nanmax(self.ts[site_id]) * np.ones(len(xb))
                breaks_source.data = dict(x=xb, y0=y0b, y1=y1b)

            ## Update BoxAnnotation to the actual inner height of p2; skip
            ## while it is not known yet (falsy) to keep the initial layout.
            inner_h = p2.inner_height
            if inner_h:
                for ba in top_ba:
                    ba.top = inner_h
                    ba.bottom = int(inner_h / 2)
                for ba in bottom_ba:
                    ba.top = int(inner_h / 2)
                    ba.bottom = 0

            ## Update p2 title text with the name of the site
            p2.title.text = 'SITE : {} (#{})'.format(site_name,
                                                     source.data['id2'][last])

        source.selected.on_change('indices', update_ts)

        ##--- Serve the file

        curdoc().add_root(
            column(file_input, select, slider, slider_date, p, p2))
        curdoc().title = "Quality monitoring"
Пример #23
0
"""

from bokeh.plotting import figure
import pandas as pd
import numpy as np
from bokeh.io import curdoc
from bokeh.models import Select, ColumnDataSource, FileInput, HoverTool, DataTable, TableColumn, DateFormatter
from bokeh.layouts import column, row
import json
from base64 import b64decode
from datetime import time
from zipfile import ZipFile
import os

# Initial (empty) widgets shown when the server first loads the page.
file_input = FileInput(accept='.json, .zip')
artist_plot = figure(title='Artist X Stream Time', plot_width=1500)
select_artist = Select(title='Select an Artist', value="", options=[])
track_plot = figure(title='Track X Stream Time', plot_width=1500)
select_track = Select(title='Select a Track', value="", options=[])

# Stack every widget in its own row, top to bottom, and register the layout
# with the current document.
layout = column(*(row(widget) for widget in (file_input, select_artist,
                                              artist_plot, select_track,
                                              track_plot)))
curdoc().add_root(layout)


def read_file(attrname, old, new):
    """
    Takes the file from file input and converts into a JSON string.

    The ``(attrname, old, new)`` signature matches a Bokeh ``on_change``
    callback.

    NOTE(review): as written, the body contains nothing but this docstring,
    so the function performs no conversion and returns None. The
    implementation appears to be missing; confirm.
    """
Пример #24
0
def prepare_server(doc,
                   input_data,
                   cell_stack,
                   cell_markers=None,
                   default_cell_marker=None):
    """Build the "Cell classifier" Bokeh app on ``doc``.

    The page shows a table of cells, an image of the selected cells for one
    marker channel of ``cell_stack`` summarized with one of
    ``CELL_IMAGE_METRICS``, widgets to add and submit column edits, and
    buttons to download or upload the edited table as CSV.

    :param doc: Bokeh document the layout is added to (its title is set).
    :param input_data: per-cell table; it is copied, never modified in place.
        Presumably a pandas DataFrame (``.copy()`` and ``.shape[0]`` are
        used); confirm with callers.
    :param cell_stack: image stack indexed as
        ``cell_stack[cells, marker, :, :]``; ``shape[2]`` and ``shape[3]``
        are used as the image width and height of the plot.
    :param cell_markers: optional marker names, one per entry of axis 1 of
        ``cell_stack``; defaults to "Marker 1", "Marker 2", ...
    :param default_cell_marker: name of the marker shown first; defaults to
        the marker at index 0.
    """

    @lru_cache()
    def image_markers(lower=False, mapping=False):
        """Return the marker names (cached per call signature).

        * default: list of names, in ``cell_stack`` axis-1 order;
        * ``lower=True``: the same names lower-cased;
        * ``mapping=True``: dict name -> axis-1 index, inserted in
          case-insensitive alphabetical order (used for the select options).
        """
        if mapping:
            return {
                y: j
                for j, y in sorted(
                    ((i, x) for i, x in enumerate(
                        image_markers(lower=lower, mapping=False))),
                    key=lambda x: x[1].lower(),
                )
            }
        if lower:
            return [
                x.lower() for x in image_markers(lower=False, mapping=False)
            ]
        return (cell_markers if cell_markers is not None else
                [f"Marker {i + 1}" for i in range(cell_stack.shape[1])])

    # Data sources
    ###########################################################################

    def prepare_data(input_data):
        """Copy ``input_data``, add derived columns and load it into ``source``.

        Adds ``contour_x``/``contour_y`` (parsed from ``contour``) when a
        ``contour`` column exists without them, and an empty ``marked``
        column when that column is missing.
        """
        data = input_data.copy()
        if "contour" in data and not all(x in data
                                         for x in ["contour_x", "contour_y"]):
            contour = parse_contour(data["contour"])
            data["contour_x"] = contour[0]
            data["contour_y"] = contour[1]
        if "marked" not in data:
            data["marked"] = np.full(data.shape[0], "")
        source.data = data

    source = ColumnDataSource(data={})
    prepare_data(input_data)
    image_source = ColumnDataSource(
        data=dict(image=[], dw=[], dh=[], contour_x=[], contour_y=[]))

    # Cell picture plot
    ###########################################################################

    def add_outline():
        """(Re)create the red cell-outline overlay on ``cell_figure``.

        Existing outline renderers are removed first so that calling this
        again after new data is loaded does not stack duplicates; a new
        outline is only drawn when ``source`` has contour columns.
        """
        for renderer in list(cell_figure.select(name="cell_outline")):
            cell_figure.renderers.remove(renderer)
        data = source.data
        if all(x in data for x in ["contour_x", "contour_y"]):
            cell_outline = cell_figure.patches(
                xs="contour_x",
                ys="contour_y",
                fill_color=None,
                line_color="red",
                name="cell_outline",
                source=image_source,
            )
            cell_outline.level = "overlay"

    # Translate the default marker name into its index in cell_stack axis 1.
    default_cell_marker = (0 if default_cell_marker is None else image_markers(
        mapping=True)[default_cell_marker])
    cell_markers_select = Select(
        value=str(default_cell_marker),
        options=list(
            (str(i), x) for x, i in image_markers(mapping=True).items()),
        title="Marker cell image",
    )
    cell_marker_input = AutocompleteInput(
        completions=list(image_markers()) + list(image_markers(lower=True)),
        min_characters=1,
        placeholder="Search for marker",
    )
    cell_slider = RangeSlider(start=0,
                              end=1,
                              value=(0, 1),
                              orientation="vertical",
                              direction="rtl")
    metric_select = RadioButtonGroup(active=0, labels=CELL_IMAGE_METRICS[0])
    stats = PreText(text="", width=100)

    cell_mapper = bokeh.models.mappers.LinearColorMapper(viridis(20),
                                                         low=0,
                                                         high=1000,
                                                         high_color=None)
    cell_color_bar = ColorBar(color_mapper=cell_mapper,
                              width=12,
                              location=(0, 0))
    cell_figure = figure(
        plot_width=450,
        plot_height=350,
        tools="pan,wheel_zoom,reset",
        toolbar_location="left",
    )
    cell_image = cell_figure.image(
        image="image",
        color_mapper=cell_mapper,
        x=0,
        y=0,
        dw="dw",
        dh="dh",
        source=image_source,
    )
    add_outline()
    cell_figure.add_layout(cell_color_bar, "right")

    # Edit data of selected cells
    ###########################################################################

    marker_edit_container = column()
    marker_edit_instances = []

    def add_marker_edit_callback():
        """Append a new ColumnEditor to the edit container."""
        editor = ColumnEditor(
            source,
            marker_edit_container,
            log_widget=edit_selection_log,
            editor_delete_callback=delete_marker_edit_callback,
            external_edit_callback=edit_selection_callback,
        )
        marker_edit_instances.append(editor)

    def delete_marker_edit_callback(editor):
        """Forget ``editor`` (matched by identity) once it has been deleted."""
        idx = next(i for i, x in enumerate(marker_edit_instances)
                   if x is editor)
        del marker_edit_instances[idx]

    file_name_text = Div()

    add_marker_edit_button = Button(label="+",
                                    button_type="success",
                                    align=("start", "end"),
                                    width=50)
    add_marker_edit_button.on_click(add_marker_edit_callback)

    edit_selection_submit = Button(label="Submit change",
                                   button_type="primary",
                                   align=("start", "end"))
    download_button = Button(label="Download edited data",
                             button_type="success",
                             align=("start", "end"))
    download_button.js_on_click(
        CustomJS(args=dict(source=source), code=DOWNLOAD_JS))

    edit_selection_log = TextAreaInput(value="",
                                       disabled=True,
                                       css_classes=["edit_log"],
                                       cols=30,
                                       rows=10)

    upload_file_input = FileInput(accept="text/csv", align=("end", "end"))

    # Cell table
    ###########################################################################

    default_data_table_cols = [
        TableColumn(field="marked", title="Seen", width=20)
    ]

    data_table = DataTable(source=source,
                           columns=default_data_table_cols,
                           width=800)

    # Callbacks for buttons and widgets
    ###########################################################################

    def cell_slider_change(attrname, old, new):
        """Use the slider range as the color mapper's low/high bounds."""
        cell_mapper.low = new[0]
        cell_mapper.high = new[1]

    def selection_change(attrname, old, new):
        """Show the summarized image of the selected cells for the chosen marker."""
        selected = source.selected.indices
        data = source.data
        if not selected:
            return
        mean_image = CELL_IMAGE_METRICS[1][metric_select.active](
            cell_stack[selected,
                       int(cell_markers_select.value), :, :], axis=0)
        image_data = {
            "image": [mean_image],
            "dw": [cell_stack.shape[2]],
            "dh": [cell_stack.shape[3]],
        }
        for coord in ["contour_x", "contour_y"]:
            try:
                image_data[coord] = list(data[coord][selected])
            except KeyError:
                pass
        image_source.data = image_data
        image_extr = round_signif(mean_image.min()), round_signif(
            mean_image.max())
        cell_slider.start = image_extr[0]
        cell_slider.end = image_extr[1]
        cell_slider.step = round_signif((image_extr[1] - image_extr[0]) / 50)
        cell_slider.value = image_extr
        stats.text = "n cells: " + str(len(selected))

    def autocomplete_cell_change(attrname, old, new):
        """Select the marker typed in the search box (case-insensitive)."""
        try:
            idx = image_markers(mapping=True)[new]
        except KeyError:
            try:
                idx = image_markers(lower=True, mapping=True)[new]
            except KeyError:
                return
        cell_markers_select.value = str(idx)
        # Clear the search box. An empty string is accepted by the String
        # property in every Bokeh version (None is rejected by non-nullable
        # properties); the re-triggered callback finds no marker and returns.
        cell_marker_input.value = ""

    def data_change(attrname, old, new):
        """Add a table column for every key newly added to ``source.data``."""
        old_keys = set(old.keys())
        for n in new.keys():
            if n not in old_keys:
                data_table.columns.append(TableColumn(field=n, title=n))

    def edit_selection_submit_click():
        """Run the edit of every ColumnEditor."""
        for x in marker_edit_instances:
            x.edit_callback()

    def edit_selection_callback():
        """Move the selection to the next row once the current one is filled.

        Applies when exactly one row is selected and, for every editor, the
        edited column's value in that row is not "NA". The selection is not
        moved past the last row (that index would not exist in
        ``cell_stack`` for ``selection_change``).
        """
        idx = source.selected.indices
        if len(idx) != 1:
            return
        row_idx = idx[0]
        data = source.data
        try:
            # Scalar indexing works for both numpy- and list-backed columns.
            filled = all(
                data[x.widgets["input_col"].value][row_idx] != "NA"
                for x in marker_edit_instances)
        except KeyError:
            return
        n_rows = len(next(iter(data.values()), []))
        if filled and row_idx + 1 < n_rows:
            source.selected.indices = [row_idx + 1]

    def upload_file_callback(attrname, old, new):
        """Replace the table with an uploaded CSV (base64 ``value``)."""
        try:
            data_text = b64decode(new)
            data = pd.read_csv(BytesIO(data_text))
        except Exception:
            file_name_text.text = f"Error loading file {upload_file_input.filename}"
            return
        file_name_text.text = f"Editing file {upload_file_input.filename}"
        # Have to regenerate contours
        try:
            del data["contour_x"]
            del data["contour_y"]
        except KeyError:
            pass
        data_table.columns = list(default_data_table_cols)
        prepare_data(data)
        add_outline()

    source.selected.on_change("indices", selection_change)
    source.on_change("data", data_change)
    cell_slider.on_change("value", cell_slider_change)
    metric_select.on_change("active", selection_change)
    cell_markers_select.on_change("value", selection_change)
    cell_marker_input.on_change("value", autocomplete_cell_change)
    edit_selection_submit.on_click(edit_selection_submit_click)
    upload_file_input.on_change("value", upload_file_callback)

    style_div = Div(text=CUSTOM_CSS)

    # set up layout
    layout = column(
        row(
            column(data_table),
            column(
                cell_markers_select,
                cell_marker_input,
                metric_select,
                row(cell_figure, cell_slider),
                stats,
            ),
        ),
        file_name_text,
        marker_edit_container,
        add_marker_edit_button,
        row(edit_selection_submit, download_button, upload_file_input),
        edit_selection_log,
        style_div,
    )

    doc.add_root(layout)
    doc.title = "Cell classifier"
Пример #25
0
from bokeh.layouts import column
from bokeh.models import CustomJS, Div, FileInput
from bokeh.plotting import output_file, show

# Widgets: a file picker limited to CSV/JSON files and a Div that echoes
# the picker's current state.
file_input = FileInput(accept=".csv,.json")
para = Div(text="<h1>FileInput Values:</h1><p>filename:<p>b64 value:")

# Browser-side code: rewrite the Div with the selected file's name and its
# base64-encoded contents.
display_values_js = """
    para.text = "<h1>FileInput Values:</h1><p>filename: " + file_input.filename  + "<p>b64 value: " + file_input.value
"""
callback = CustomJS(args={"para": para, "file_input": file_input},
                    code=display_values_js)

# Run the JS callback on the widget's 'change' event.
file_input.js_on_change('change', callback)

output_file("file_input.html")
page = column(file_input, para)
show(page)
Пример #26
0
class UIClass:
    """Interactive Bokeh UI for step-by-step demand forecasting.

    The user picks a data source, then the demand-value, product-id and date
    columns (or generates dates from selected business days), previews one
    product's history and requests a Prophet forecast. Widgets are shown and
    hidden as the user advances; ``display`` wires the callbacks and adds the
    layout to ``curdoc()``.
    """

    # Column written by ``workdays_button_pressed`` when dates are generated.
    GENERATED_DATE_COLNAME = 'generated_date'
    # Number of periods forecast by Prophet and drawn as the red line.
    FORECAST_DAYS = 30

    def __init__(self):
        self.input_df = pd.DataFrame({
            'x': ['2010-01-01'] * DF_NUM_PREVIEW_ROWS,
            'y': [0] * DF_NUM_PREVIEW_ROWS
        })
        self.forecasted_df = None
        self.datefmt = DateFormatter(format='%m-%d-%Y')
        self.inputs = None
        self.x_range = [0, 10]
        self.demand_plot = figure(x_range=self.x_range,
                                  x_axis_type="datetime",
                                  tools=["pan", 'wheel_zoom'])

        self.plot_data_source = ColumnDataSource(data=self.input_df)
        self.line1 = self.demand_plot.line(x='x',
                                           y='y',
                                           source=self.plot_data_source,
                                           line_color='blue',
                                           name='line1')
        # Forecast renderer; only created by the first prediction.
        self.line2 = None
        self.demand_plot.xaxis.formatter = DatetimeTickFormatter(
            days="%d %b %Y", hours="")
        self.demand_plot.axis.minor_tick_line_color = None
        self.demand_plot.xaxis[0].ticker.desired_num_ticks = 10
        self.demand_plot.xaxis.major_label_orientation = radians(30)

        # Set up widgets
        today = pd.to_datetime("today").date()
        self.data_source_selector = Select(
            title='Step 1/5: Select Data',
            value='Not Selected',
            options=['Not Selected', 'Use Example Data', 'Upload Data'])
        self.file_input = FileInput(accept='.csv,.xlsx')
        self.data_table = DataTable(
            height=DATATABLE_PREVIEW_HEIGHT,
            width=DATATABLE_PREVIEW_WIDTH,
            fit_columns=False,
            index_position=None,
            margin=(0, 15, 0, 15),
        )
        self.data_preview_paragraph = Paragraph(text='Data Preview:',
                                                margin=(0, 15, 0, 15))
        self.values_col_selector = Select(
            title='Step 2/5: Select column with demand values',
            value='Not Selected',
            options=['Not Selected'])
        self.product_id_col_selector = Select(
            title='Step 3/5: Select column with product ID',
            value='Not Selected',
            options=['Not Selected'])
        self.date_col_selector = Select(title="Step 4/5: Select date column",
                                        value='Not Selected',
                                        options=['Not Selected'])
        self.last_date_picker = DatePicker(
            title='Select the date of last observation',
            max_date=today,
            value=today)
        self.workdays_checkboxgroup = CheckboxGroup(
            labels=["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
            active=[],
            inline=True,
            margin=(0, 15, 0, 0))
        self.workdays_apply_button = Button(label='Select Business Days',
                                            button_type='primary')
        self.product_selector_plotting = Select(
            title='Select Product to Display',
            value='v1',
            options=['v1', 'v2'])
        self.prediction_button = Button(
            label='Forecast Demand for Selected Product ID',
            button_type='primary')
        self.default_info_msg = 'This window will contain additional information,\nas you interact with the app.'
        self.info_paragraph = PreText(
            text='Details:\n{}'.format(self.default_info_msg))

        # Widgets whose visibility is managed by the step logic. The info
        # paragraph is deliberately absent so it always stays visible.
        self.widgets = {
            'data_source_selector': self.data_source_selector,
            'file_input': self.file_input,
            'values_col_selector': self.values_col_selector,
            'product_id_col_selector': self.product_id_col_selector,
            'data_preview_paragraph': self.data_preview_paragraph,
            'data_table': self.data_table,
            'product_selector': self.product_selector_plotting,
            'demand_plot': self.demand_plot,
            'date_col_selector': self.date_col_selector,
            'last_date_picker': self.last_date_picker,
            'workdays_checkboxgroup': self.workdays_checkboxgroup,
            'workdays_apply_button': self.workdays_apply_button,
            'prediction_button': self.prediction_button,
        }

        self.values_colname = None
        self.product_id_colname = None
        self.date_colname = None
        self.product_ids = []
        # True once GENERATED_DATE_COLNAME in input_df was written by this
        # class (so regenerating it is allowed); reset whenever new data loads.
        self._owns_generated_date_col = False

    ########## WIDGETS VISIBILITY CONTROLS ##########

    def _change_widgets_visibility(self, names, names_show_or_hide='show'):
        """Set ``visible`` on every managed widget.

        Widgets listed in *names* get ``visible = (names_show_or_hide ==
        'show')``; every other widget gets the opposite value.
        """
        displaying = names_show_or_hide == 'show'
        for widget_name, widget in self.widgets.items():
            widget.visible = displaying if widget_name in names else not displaying

    def display_all_widgets_except(self, widgets=()):
        """Hide the widgets named in *widgets* and show all others."""
        self._change_widgets_visibility(widgets, 'hide')

    def hide_all_widgets_except(self, widgets=()):
        """Show the widgets named in *widgets* and hide all others."""
        self._change_widgets_visibility(widgets, 'show')

    ########## LOGIC ##########

    def set_widget_to_default_value(self,
                                    widget_names,
                                    default_val='Not Selected'):
        """Assign *default_val* to ``value`` of each named widget."""
        for widget_name in widget_names:
            self.widgets[widget_name].value = default_val

    def prepare_values_col_selection(self):
        """Offer every ``input_df`` column in the demand-values selector."""
        self.values_col_selector.options = (['Not Selected'] +
                                            self.input_df.columns.tolist())

    def get_additional_cols_to_show(self):
        """Return extra widget names to keep visible for the current source."""
        if self.data_source_selector.value == 'Upload Data':
            return ['file_input']
        return []

    def update_details_msg(self, msg):
        """Show *msg* in the details panel."""
        self.info_paragraph.text = f"Details:\n{msg}"

    def _columns_except(self, *excluded):
        """Return ``input_df`` column names in table order, minus *excluded*."""
        return [col for col in self.input_df.columns if col not in excluded]

    def _product_mask(self, product_id):
        """Boolean mask of ``input_df`` rows belonging to *product_id*.

        ``self.product_ids`` and the Select widget values are strings, while
        the product-id column may be numeric, so compare string forms.
        """
        return (self.input_df[self.product_id_colname].astype(str) ==
                str(product_id))

    def _refresh_data_table(self):
        """Point the preview table at the first rows of ``input_df``."""
        # https://stackoverflow.com/questions/40942168/how-to-create-a-bokeh-datatable-datetime-formatter
        self.data_table.columns = [
            TableColumn(field=col, title=col, width=DATATABLE_PREVIEW_COL_WIDTH)
            for col in self.input_df.columns
        ]
        self.data_table.source = ColumnDataSource(
            self.input_df.head(DF_NUM_PREVIEW_ROWS))

    def preview_input_df(self):
        """Refresh the preview table and make it visible."""
        self._refresh_data_table()
        self.data_table.visible = True
        self.data_preview_paragraph.visible = True

    def upload_fit_data(self, attr, old, new):
        """Parse the file uploaded through ``self.file_input`` into ``input_df``.

        The widget value is base64. Decoded bytes starting with the ZIP
        signature (``.xlsx`` workbooks are ZIP archives) are read with
        ``pd.read_excel``; anything else is read as comma-separated UTF-8
        text. On failure an error is shown and the previous data is kept.
        """
        from io import BytesIO

        self.update_details_msg(msg='Step 1/5: Uploading data')
        raw_bytes = base64.b64decode(self.file_input.value)
        try:
            if raw_bytes[:4] == b'PK\x03\x04':
                uploaded_df = pd.read_excel(BytesIO(raw_bytes))
            else:
                uploaded_df = pd.read_csv(BytesIO(raw_bytes), sep=',')
        except Exception as err:  # UI boundary: report instead of crashing
            print('fit data upload failed: {}'.format(err))
            self.update_details_msg(
                msg='ERROR: Step 1/5: could not read the uploaded file.\n'
                'Please upload data in one of the\n'
                'following formats: .CSV or .XLSX')
            return
        print('fit data upload succeeded')
        self.input_df = uploaded_df
        self._owns_generated_date_col = False
        self.update_details_msg(
            msg='Step 1/5: Data has been successfully uploaded!')
        print('Input DF shape: {}'.format(self.input_df.shape))
        self.prepare_values_col_selection()
        self.hide_all_widgets_except(
            ['data_source_selector', 'file_input', 'values_col_selector'])
        self.preview_input_df()

    def replace_selector_options(self, selector, old_value, new_options):
        """Swap *selector*'s options for *new_options*, selecting the first.

        *old_value* is kept in the options while the value changes.
        """
        selector.options = [old_value] + new_options
        selector.value = new_options[0]
        selector.options = new_options

    def date_col_integrity(self, date_colname):
        """Validate *date_colname* and convert it to datetimes in place.

        Returns ``'ok'`` on success, otherwise a short reason string. The
        column is only modified when the conversion succeeds.
        """
        as_text = self.input_df[date_colname].astype(str)
        if as_text.empty:
            return 'column is empty'
        first = as_text.iloc[0]  # positional: the index need not contain 0
        if '-' in first:
            sep = '-'
        elif '/' in first:
            sep = '/'
        else:
            return 'no separator found'
        if not (as_text.str.split(sep).str.len() == 3).all():
            return 'not all dates have exactly 3 components'
        try:
            parsed = pd.to_datetime(as_text)
        except (ValueError, TypeError, OverflowError):
            return 'error converting to datetime'
        self.input_df[date_colname] = parsed
        return 'ok'

    def _clear_legends(self):
        """Remove every item from every legend on the demand plot."""
        for legend in self.demand_plot.legend:
            legend.items = []

    def _remove_renderer(self, renderer):
        """Detach *renderer* from the demand plot if it is attached."""
        if renderer is not None and renderer in self.demand_plot.renderers:
            self.demand_plot.renderers.remove(renderer)

    def _set_plot_ranges(self, x_values, y_values):
        """Fit the demand plot's axis ranges to the given Series."""
        self.demand_plot.x_range.start = x_values.min()
        self.demand_plot.x_range.end = x_values.max()
        self.demand_plot.y_range.start = y_values.min()
        self.demand_plot.y_range.end = y_values.max()

    def display_preview_plot(self):
        """Populate the product selector and plot the first product."""
        self.replace_selector_options(self.product_selector_plotting, 'v1',
                                      self.product_ids)
        self.product_selector_plotting.visible = True
        self.prediction_button.visible = True
        self._remove_renderer(self.line1)
        first_product = self.product_ids[0]
        self.plot_data_source = ColumnDataSource(
            data=self.input_df[self._product_mask(first_product)])
        self.line1 = self.demand_plot.line(x=self.date_colname,
                                           y=self.values_colname,
                                           source=self.plot_data_source,
                                           line_color='blue',
                                           name='line1')
        self.update_plot(None, None, first_product)
        self.demand_plot.visible = True

    def generate_dates(self, end_date: datetime.datetime, work_days: list,
                       num_periods: int):
        """Return *num_periods* business days ending at *end_date*.

        *work_days* are day abbreviations such as ``['Mon', 'Tue']`` and
        form the ``CustomBusinessDay`` weekmask.
        """
        weekmask = ' '.join(work_days)  # e.g. 'Sun Mon Tue Wed Fri'
        freq = pd.offsets.CustomBusinessDay(weekmask=weekmask)
        return pd.date_range(end=end_date, periods=num_periods, freq=freq)

    def clean_df(self):
        """
        Modifies self.input_df:
        1) Removing duplicates based on [self.date_colname, self.product_id_colname]
        2) Sorting based on self.date_colname
        :return: void
        """
        # Build a fresh frame rather than sorting a filtered copy in place,
        # which would leave input_df flagged as a copy for later assignments.
        self.input_df = self.input_df.drop_duplicates(
            subset=[self.date_colname, self.product_id_colname],
            keep='first').sort_values(by=self.date_colname)
        print('===RESULTED INPUT_DF SHAPE AFTER CLEANING: ',
              self.input_df.shape)

    ########## WIDGETS ON_CHANGE METHODS ##########

    def select_data_source(self, attrname, old_val, new_val):
        """React to the data-source selector (step 1)."""
        self.set_widget_to_default_value([
            'values_col_selector', 'product_id_col_selector',
            'date_col_selector'
        ])
        if new_val == 'Upload Data':
            self.update_details_msg(
                msg=
                'Step 1/5: Please upload data in one of the\nfollowing formats: .CSV or .XLSX'
            )
            self.hide_all_widgets_except(
                ['data_source_selector', 'file_input'])
        elif new_val == 'Use Example Data':
            self.update_details_msg(
                msg=
                'Step 1/5: Using a sample toy data. You can use it\nto test the functionality of this app.'
            )
            self.input_df = pd.read_csv('default_table.csv')
            self._owns_generated_date_col = False
            self.prepare_values_col_selection()
            self.preview_input_df()
            self.hide_all_widgets_except([
                'data_source_selector', 'values_col_selector',
                'data_preview_paragraph', 'data_table'
            ])
        else:  # Not Selected
            self.update_details_msg(msg=self.default_info_msg)
            self.hide_all_widgets_except(['data_source_selector'])

    def select_values_colname(self, attrname, old_val, new_val):
        """React to the demand-values column selector (step 2)."""
        self.update_details_msg(
            msg=
            'Step 2/5: Please select a column that contains\nthe demand values. Note, that all the values in\nthis column should be numerical.'
        )
        self.set_widget_to_default_value(
            ['product_id_col_selector', 'date_col_selector'])
        self.hide_all_widgets_except([
            'data_source_selector', 'values_col_selector',
            'data_preview_paragraph', 'data_table'
        ] + self.get_additional_cols_to_show())
        if new_val == 'Not Selected':
            return
        self.values_colname = new_val
        try:
            self.input_df[self.values_colname] = self.input_df[
                self.values_colname].astype(float)
        except (ValueError, TypeError):
            self.update_details_msg(
                msg=
                'WARNING! Step 2/5: Not all the values\nin selected column are numerical!'
            )
            return
        self.product_id_col_selector.options = ['Not Selected'] + \
            self._columns_except(self.values_colname, self.date_colname)
        self.product_id_col_selector.visible = True

    def select_product_id_colname(self, attrname, old_val, new_val):
        """React to the product-id column selector (step 3)."""
        self.update_details_msg(
            msg=
            "Step 3/5: Please select a column that contains products' identifiers."
        )
        self.set_widget_to_default_value(['date_col_selector'])
        self.hide_all_widgets_except([
            'data_source_selector', 'values_col_selector',
            'data_preview_paragraph', 'data_table', 'product_id_col_selector'
        ] + self.get_additional_cols_to_show())
        if new_val == 'Not Selected':
            return
        self.product_id_colname = new_val
        # Stringify before de-duplicating so ids match _product_mask exactly.
        self.product_ids = self.input_df[
            self.product_id_colname].astype(str).unique().tolist()
        self.date_col_selector.options = ['Not Selected'] + \
            self._columns_except(self.values_colname,
                                 self.product_id_colname, self.date_colname)
        for widget in (self.date_col_selector, self.last_date_picker,
                       self.workdays_checkboxgroup,
                       self.workdays_apply_button):
            widget.visible = True

    def select_date_column(self, attrname, old_val, new_val):
        """React to the date column selector (step 4)."""
        self.update_details_msg(
            msg="Step 4/5: If there is a date column, please select it's name.\n"
            "Note: Dates should be in one of the following formats:\n"
            "yyyy-mm-dd OR mm-dd-yyyy OR yyyy/mm/dd OR mm/dd/yyyy\n"
            "If there is no such column, use 'Not Selected' option.")
        self.hide_all_widgets_except([
            'data_source_selector', 'values_col_selector',
            'data_preview_paragraph', 'data_table', 'product_id_col_selector',
            'date_col_selector'
        ] + self.get_additional_cols_to_show())
        if new_val == 'Not Selected':
            self.last_date_picker.visible = True
            self.workdays_checkboxgroup.visible = True
            self.workdays_apply_button.visible = True
            return
        self.date_colname = new_val
        date_col_integrity_status = self.date_col_integrity(self.date_colname)
        if date_col_integrity_status == 'ok':
            self.clean_df()
            self.display_preview_plot()
        else:
            print('date_col_integrity_status: ', date_col_integrity_status)
            self.update_details_msg(
                msg=
                "ERROR: selected date column doesn't satisfy specified requirements:\n"
                "Dates should be in one of the following formats:\n"
                "yyyy-mm-dd OR mm-dd-yyyy OR yyyy/mm/dd OR mm/dd/yyyy\n"
                "If there is no such column, use 'Not Selected' option.")

    def select_last_date(self, attrname, old_val, new_val):
        """React to the last-observation date picker."""
        self.update_details_msg(
            msg="Alright, dates will be automatically generated for you!\n"
            "Select days when your business works.")
        self.workdays_checkboxgroup.visible = True
        self.workdays_apply_button.visible = True

    def workdays_button_pressed(self, new):
        """Generate per-product business-day dates ending at the picked date."""
        if not self.workdays_checkboxgroup.active:
            self.update_details_msg(
                msg="Please select at least one business day.")
            return
        self.update_details_msg(msg="Generating dates.")
        generated_colname = self.GENERATED_DATE_COLNAME
        # Refuse to overwrite a user column with the same name, but allow
        # regenerating a column this class created earlier.
        if (generated_colname in self.input_df.columns
                and not self._owns_generated_date_col):
            self.update_details_msg(
                msg="Please rename the {} column in your table.".format(
                    generated_colname))
            return
        labels = self.workdays_checkboxgroup.labels
        work_days = [labels[i] for i in self.workdays_checkboxgroup.active]
        self.date_colname = generated_colname
        self.input_df[generated_colname] = ''
        for product_id in self.product_ids:
            inds = self._product_mask(product_id)
            self.input_df.loc[inds, generated_colname] = self.generate_dates(
                end_date=self.last_date_picker.value,
                work_days=work_days,
                num_periods=int(inds.sum()))
        self.input_df[generated_colname] = pd.to_datetime(
            self.input_df[generated_colname])
        self._owns_generated_date_col = True
        self.clean_df()
        self.display_preview_plot()

    def prediction_button_pressed(self, new):
        """Fit Prophet on the selected product and plot history + forecast."""
        selected_product = self.product_selector_plotting.value
        print('Preparing forecast for product: ', selected_product)
        rows = self.input_df.loc[self._product_mask(selected_product)]
        train_dataset = pd.DataFrame({
            'ds': rows[self.date_colname],
            'y': rows[self.values_colname],
        })
        print('Train Dataset shape: ', train_dataset.shape)
        for kind, payload in self.make_predictions(
                train_dataset, days_ahead=self.FORECAST_DAYS):
            if kind == 'msg':
                print('Message: ', payload)
                continue
            self.forecasted_df = payload.rename(columns={'yhat': 'y'})
            print('Done; shape: ', self.forecasted_df.shape)
            self._plot_forecast(train_dataset)

    def _plot_forecast(self, train_dataset):
        """Redraw the plot: history in blue, forecast tail in red."""
        forecast_tail = self.forecasted_df.tail(self.FORECAST_DAYS)
        # Series.append was removed in pandas 2.0; pd.concat replaces it.
        combined_dataset = pd.DataFrame({
            'ds': pd.concat([train_dataset['ds'], forecast_tail['ds']],
                            ignore_index=True),
            'y': pd.concat([train_dataset['y'], forecast_tail['y']],
                           ignore_index=True),
        })

        self._clear_legends()
        self._remove_renderer(self.line1)
        self._remove_renderer(self.line2)

        self.plot_data_source = ColumnDataSource(data=combined_dataset)
        self.line1 = self.demand_plot.line(x=train_dataset['ds'],
                                           y=train_dataset['y'],
                                           line_color='blue',
                                           name='line1',
                                           legend_label='Historical')
        # Start the forecast at the last observation so both lines connect.
        last_obs = train_dataset.tail(1)
        self.line2 = self.demand_plot.line(
            x=pd.concat([last_obs['ds'], forecast_tail['ds']],
                        ignore_index=True),
            y=pd.concat([last_obs['y'], forecast_tail['y']],
                        ignore_index=True),
            line_color='red',
            name='line2',
            legend_label='Forecast')
        self.demand_plot.legend.location = "top_left"
        self._set_plot_ranges(combined_dataset['ds'], combined_dataset['y'])
        self.demand_plot.visible = True

    ########## OTHER ##########

    def dates_diff_count(self, df, product_name):
        """Show how often each gap (in days) between consecutive dates occurs."""
        dates = df[self.date_colname].to_numpy()
        # Divide by a one-day timedelta to get float days; dividing a
        # timedelta64 by plain integers would truncate fractional days.
        days_diffs = np.diff(dates) / np.timedelta64(1, 'D')
        unique_diffs, diffs_counts = np.unique(days_diffs, return_counts=True)
        lines = ['Product: {}:'.format(product_name), '# Days Delta ; Count']
        lines += [
            '{:10} ; {}'.format(value, count)
            for value, count in zip(unique_diffs, diffs_counts)
        ]
        lines.append('If there is more than one unique value\n'
                     'it can make forecast less accurate')
        self.update_details_msg(msg='\n'.join(lines))

    # https://facebook.github.io/prophet/docs/non-daily_data.html
    def make_predictions(self, df, days_ahead=30):
        """Fit Prophet on *df* (columns ``ds``/``y``) and forecast.

        Yields ``['msg', text]`` progress items, then a final
        ``['results', forecast[['ds', 'yhat']]]``.
        """
        yield ['msg', 'training model']
        prophet = Prophet(weekly_seasonality=False, daily_seasonality=False)
        prophet.fit(df)
        yield ['msg', 'making predictions']
        future = prophet.make_future_dataframe(periods=days_ahead)
        forecast = prophet.predict(future)
        yield ['results', forecast[['ds', 'yhat']]]

    def update_plot(self, attrname, old, new):
        """Plot the history of product *new* (product selector callback)."""
        self._clear_legends()
        self._remove_renderer(self.line2)

        sub_df = self.input_df[self._product_mask(new)]
        self.dates_diff_count(sub_df, new)
        self._remove_renderer(self.line1)
        self.plot_data_source = ColumnDataSource(data=sub_df)
        self.line1 = self.demand_plot.line(x=self.date_colname,
                                           y=self.values_colname,
                                           source=self.plot_data_source,
                                           line_color='blue',
                                           legend_label='Historical',
                                           name='line1')
        self.demand_plot.legend.location = "top_left"
        self._set_plot_ranges(sub_df[self.date_colname],
                              sub_df[self.values_colname])

    ########## MAIN ##########

    def display(self):
        """Wire widget callbacks, build the layout and add it to curdoc()."""
        self.file_input.on_change('value', self.upload_fit_data)
        # NOTE(review): self.plot is not part of the layout; kept so the
        # attribute still exists. width/height replace plot_width/plot_height,
        # which Bokeh 3 removed.
        self.plot = figure(height=400,
                           width=400,
                           title='my sine wave',
                           tools='crosshair,pan,reset,save,wheel_zoom')

        self.inputs = column(self.data_source_selector, self.file_input,
                             self.values_col_selector,
                             self.product_id_col_selector,
                             self.date_col_selector, self.last_date_picker,
                             self.workdays_checkboxgroup,
                             self.workdays_apply_button)

        self.hide_all_widgets_except(['data_source_selector'])
        self.data_source_selector.on_change('value', self.select_data_source)
        self.values_col_selector.on_change('value', self.select_values_colname)
        self.product_id_col_selector.on_change('value',
                                               self.select_product_id_colname)
        self.product_selector_plotting.on_change('value', self.update_plot)
        self.date_col_selector.on_change('value', self.select_date_column)
        self.last_date_picker.on_change('value', self.select_last_date)
        self.workdays_apply_button.on_click(self.workdays_button_pressed)
        self.prediction_button.on_click(self.prediction_button_pressed)

        self._refresh_data_table()
        self.col_middle = column(self.data_preview_paragraph, self.data_table)

        self.layout = column(
            row(
                column(
                    self.data_source_selector,
                    self.file_input,
                    self.values_col_selector,
                    self.product_id_col_selector,
                ), column(self.data_preview_paragraph, self.data_table),
                self.info_paragraph),
            row(
                column(self.date_col_selector, self.last_date_picker,
                       self.workdays_checkboxgroup, self.workdays_apply_button,
                       self.product_selector_plotting, self.prediction_button),
                self.demand_plot))

        curdoc().add_root(self.layout)
        curdoc().title = 'Demand Forecasting'
Пример #27
0
def spectrum_slice_app(doc):
    '''Build the interactive spectrum-slice viewer and add it to ``doc``.

    The page offers a file chooser (preset files listed in a YAML manifest
    fetched from ``resource_url + manifest_name``, plus a local .wav upload),
    a waveform plot, a spectrogram sharing the waveform's time axis, a
    spectrum-slice plot for the time clicked in either of those, and
    controls for playback, low-power threshold, analysis window size and
    passband/stopband filtering.

    Several names used here are not defined in this function and must be
    provided by the enclosing module (globals or imports): ``params``,
    ``tempdir``, ``resource_url``, ``manifest_name``, ``manifest_key``,
    ``get_cached_fname``, ``snd2specgram``, ``AudioButton``, ``r_Greys256``, ...

    NOTE(review): ``wavname`` and ``snd`` are module globals, so every
    document built by this function shares them (e.g. concurrent sessions on
    one Bokeh server would overwrite each other's sound). Consider moving
    them into this function's scope with ``nonlocal`` once nothing else in
    the module reads them.

    :param doc: Bokeh document to populate.
    :return: ``doc``, with the main layout added as a root.
    '''
    def load_wav_cb(attr, old, new):
        '''Handle selection of audio file to be loaded.

        ``new`` is the selected ``fselect`` option value (a manifest URL or a
        cached upload path). It is split into directory and file name and
        passed to ``get_cached_fname`` (presumably returning a local cached
        copy); nothing is updated unless the result ends in '.wav'.
        '''
        global wavname
        base_url, fname = os.path.split(new)
        wavname = get_cached_fname(fname, base_url, tempdir)
        if not wavname.endswith('.wav'):
            return
        update_snd()
        # Point both audio buttons at the freshly loaded signal.
        playvisbtn.channels = channels
        playvisbtn.visible = True
        playselbtn.channels = channels
        playselbtn.visible = True
        playvisbtn.fs = snd.sampling_frequency
        playvisbtn.start = snd.start_time
        playvisbtn.end = snd.end_time
        playselbtn.fs = snd.sampling_frequency
        # No selection exists yet for the new file.
        playselbtn.start = 0.0
        playselbtn.end = 0.0
        ch0.visible = True
        update_sgram()
        update_spslice(t=None)

    def file_input_cb(attr, old, new):
        '''Handle audio file upload.

        ``new`` is the base64-encoded file content from ``file_input``. The
        sound is resampled to ``params['downsample_rate']``, saved as WAV in
        ``tempdir`` and appended to ``fselect``; selecting the new option
        fires ``load_wav_cb`` through fselect's 'value' callback.
        '''
        with NamedTemporaryFile() as upload:
            upload.write(b64decode(new))
            # The write is buffered: flush so the bytes are on disk before
            # parselmouth reopens the file by name.
            # NOTE(review): on Windows an open NamedTemporaryFile cannot be
            # reopened by name; this path assumes a POSIX server.
            upload.flush()
            tempsnd = parselmouth.Sound(upload.name)
        ds_snd = tempsnd.resample(params['downsample_rate'], 50)
        # The filename is supplied by the client; keep only its final
        # component so a crafted name such as '../x' cannot escape tempdir.
        fname = os.path.basename(file_input.filename)
        cachefile = os.path.join(tempdir, fname)
        ds_snd.save(cachefile, parselmouth.SoundFileFormat.WAV)
        # Assign a new list so Bokeh detects the change.
        options = fselect.options.copy()
        options += [(cachefile, fname)]
        fselect.options = options
        fselect.value = fselect.options[-1][0]

    def update_snd():
        '''Update the sound (waveform and audio button).

        Loads ``wavname`` into the global ``snd`` (mixed down to mono when it
        has several channels), applies the selected Praat Hann band filter
        when a passband/stopband is chosen and a frequency range has been
        selected, and pushes the samples into ``source``.
        '''
        global snd
        snd = parselmouth.Sound(wavname)
        if snd.n_channels > 1:
            snd = snd.convert_to_mono()
        if filter_sel.value not in ('no filter', 'no filter (clear)'
                                    ) and spselbox.right is not None:
            if filter_sel.value.startswith('stopband'):
                func = 'Filter (stop Hann band)...'
            if filter_sel.value.startswith('passband'):
                func = 'Filter (pass Hann band)...'
            # Band edges come from the spectrum-slice selection; 100.0 is
            # presumably the Hann smoothing width in Hz (confirm in Praat docs).
            snd = parselmouth.praat.call(snd, func, spselbox.left,
                                         spselbox.right, 100.0)
        source.data = dict(
            seconds=snd.ts().astype(np.float32),
            ch0=snd.values[0, :].astype(np.float32),
        )

    def update_sgram():
        '''Update spectrogram based on current values.

        Recomputes ``sgrams[0]`` from ``snd`` with the window size from
        ``winsize_slider`` (milliseconds converted to seconds) and refreshes
        the image, its extent and the colormap's low threshold.
        '''
        if filter_sel.value == 'no filter (clear)':
            sgselbox.bottom = None
            sgselbox.top = None
            sgselbox.visible = False
        else:
            sgselbox.visible = True
        sgrams[0] = snd2specgram(snd, winsize=winsize_slider.value * 10**-3)
        specsource.data = dict(sgram0=[sgrams[0].values.astype(np.float32)])
        spec0img.glyph.dw = sgrams[0].x_grid().max()
        spec0img.glyph.dh = sgrams[0].y_grid().max()
        spec0cmap.low = _low_thresh()
        spec0.visible = True

    def update_spslice(t=None):
        '''Update spslice plot with spectrum slice at time t.

        :param t: time in seconds of the slice, or None to clear the slice
            and hide the frequency markers.
        '''
        if t is not None:
            sgram = sgrams[0]
            frame = parselmouth.praat.call(sgram,
                                           'Get frame number from time...',
                                           t)
            # Praat frame numbers are 1-based and may fall outside
            # [1, n_frames] near or beyond the edges; convert to a 0-based
            # column index of ``values`` (frequency x time) and clamp it.
            n_frames = sgram.values.shape[1]
            slidx = int(np.clip(np.round(frame) - 1, 0, n_frames - 1))
            spslice = sgram.values[:, slidx]
            spdata = dict(freq=np.arange(sgram.values.shape[0]) * sgram.dy,
                          power=spslice)
            spselbox.visible = True
        else:
            spdata = dict(freq=np.array([]), power=np.array([]))
            spec0_fq_marker.visible = False
            spslice0_fq_marker.visible = False
            spselbox.visible = False
        if filter_sel.value == 'no filter (clear)':
            spselbox.left = None
            spselbox.right = None
            spselbox.visible = False
        spslice_source.data = spdata
        spslice0.x_range = Range1d(0.0, sgrams[0].get_highest_y())
        spslice0.y_range = Range1d(0.0, sgrams[0].get_maximum())
        thresh_box.top = _low_thresh()
        try:
            # Re-read the power at the marker's frequency in the new slice.
            fqidx = np.abs(spslice_source.data['freq'] -
                           fq_marker_source.data['freq'][0]).argmin()
            fq_marker_source.data['power'] = [
                spslice_source.data['power'][fqidx]
            ]
        except ValueError:
            pass  # Not set yet (argmin of an empty slice)

    def cursor_cb(e):
        '''Handle cursor mouse click that creates the spectrum slice.

        ``e.x`` is the clicked time and ``e.y`` the clicked frequency on the
        spectrogram; the frequency marker is placed at the nearest bin.
        '''
        cursor.location = e.x
        update_spslice(t=e.x)
        idx = np.abs(spslice_source.data['freq'] - e.y).argmin()
        fq_marker_source.data = dict(freq=[e.y],
                                     power=[spslice_source.data['power'][idx]],
                                     time=[e.x])
        params['spslice_lastx'] = e.y
        spec0_fq_marker.visible = True
        spslice0_fq_marker.visible = True

    def spslice_move_cb(e):
        '''Handle a MouseMove event on spectrum slice crosshair tool.

        Moves the frequency marker to the bin nearest the pointer's x
        (frequency) when it changed and lies within the slice's range.
        '''
        try:
            if (params['spslice_lastx'] != e.x
                    and 0 <= e.x <= spslice_source.data['freq'][-1]):
                params['spslice_lastx'] = e.x
                idx = np.abs(spslice_source.data['freq'] - e.x).argmin()
                fq_marker_source.data['freq'] = [
                    spslice_source.data['freq'][idx]
                ]
                fq_marker_source.data['power'] = [
                    spslice_source.data['power'][idx]
                ]
        except IndexError:  # data not loaded yet
            pass

    def x_range_cb(attr, old, new):
        '''Handle change of x range in waveform/spectrogram.

        Keeps the "play visible" button's start/end in sync with the view.
        '''
        if attr == 'start':
            playvisbtn.start = new
        elif attr == 'end':
            playvisbtn.end = new

    def selection_cb(e):
        '''Handle data range selection event.

        Copies the selected time span to the "play selected" button and
        shows it as the green selection box.
        '''
        playselbtn.start = e.geometry['x0']
        playselbtn.end = e.geometry['x1']
        selbox.left = e.geometry['x0']
        selbox.right = e.geometry['x1']
        selbox.visible = True

    def low_thresh_cb(attr, old, new):
        '''Handle change in threshold slider to fade out low spectrogram values.'''
        params['low_thresh_power'] = new
        lt = _low_thresh()
        spec0cmap.low = lt
        thresh_box.top = lt

    def _low_thresh():
        '''Return min + std ** params['low_thresh_power'] of the spectrogram.'''
        return sgrams[0].values.min() \
               + sgrams[0].values.std()**params['low_thresh_power']

    def winsize_cb(attr, old, new):
        '''Handle change in winsize slider to change spectrogram analysis window.'''
        params['window_size'] = new
        update_sgram()
        if cursor.location is not None:
            update_spslice(t=cursor.location)
            idx = np.abs(spslice_source.data['freq'] -
                         params['spslice_lastx']).argmin()
            fq_marker_source.data = dict(
                freq=[spslice_source.data['freq'][idx]],
                power=[spslice_source.data['power'][idx]],
                time=[cursor.location])

    def filter_sel_cb(e):
        '''Handle change of filter range.

        The x span selected on the spectrum slice becomes the filter band;
        the passband/stopband option labels are rewritten to show it.
        '''
        lowfq = e.geometry['x0']
        highfq = e.geometry['x1']
        sgselbox.bottom = lowfq
        sgselbox.top = highfq
        spselbox.left = lowfq
        spselbox.right = highfq
        # Leading space separates the range from the option name.
        range_text = f' ({lowfq:.0f}-{highfq:.0f} Hz)'
        # Force assignment of new options so that Bokeh detects the values have changed
        # and synchronizes the JS.
        options = filter_sel.options.copy()
        for idx, opt in enumerate(options):
            if 'stopband' in opt:
                options[idx] = f'stopband{range_text}'
                if 'stopband' in filter_sel.value:
                    filter_sel.value = options[idx]
            if 'passband' in opt:
                options[idx] = f'passband{range_text}'
                if 'passband' in filter_sel.value:
                    filter_sel.value = options[idx]
        filter_sel.options = options
        update_snd()
        update_sgram()
        update_spslice(t=cursor.location)

    def filter_type_cb(attr, old, new):
        '''Handle change in filter type.

        Choosing 'no filter (clear)' resets the band labels to the bare
        'passband'/'stopband' names; the sound, spectrogram and slice are
        recomputed in every case.
        '''
        if 'clear' in new:
            # Force assignment of new options so that Bokeh detects the values have changed
            # and synchronizes the JS.
            options = filter_sel.options.copy()
            for idx, opt in enumerate(options):
                if 'passband' in opt:
                    options[idx] = 'passband'
                    if 'passband' in filter_sel.value:
                        filter_sel.value = 'passband'
                if 'stopband' in opt:
                    options[idx] = 'stopband'
                    if 'stopband' in filter_sel.value:
                        filter_sel.value = 'stopband'
            filter_sel.options = options
        update_snd()
        update_sgram()
        update_spslice(t=cursor.location)

    # File selection: preset files from the remote manifest plus uploads.
    manifest_text = requests.get(resource_url + manifest_name).text
    manifest = yaml.safe_load(manifest_text)[manifest_key]
    options = [('', 'Choose an audio file to display')] + [
        (resource_url + opt['fname'], opt['label']) for opt in manifest
    ]
    fselect = Select(options=options, value='')
    fselect.on_change('value', load_wav_cb)
    file_input = FileInput(accept=".wav")
    fselect_row = row(fselect, file_input)
    file_input.on_change('value', file_input_cb)
    source = ColumnDataSource(data=dict(seconds=[], ch0=[]))
    channels = ['ch0']

    playvisbtn = AudioButton(label='Play visible signal',
                             source=source,
                             channels=channels,
                             width=120,
                             visible=False)
    playselbtn = AudioButton(label='Play selected signal',
                             source=source,
                             channels=channels,
                             width=120,
                             visible=False)

    # Instantiate and share specific select/zoom tools so that
    # highlighting is synchronized on all plots.
    boxsel = BoxSelectTool(dimensions='width')
    spboxsel = BoxSelectTool(dimensions='width')
    boxzoom = BoxZoomTool(dimensions='width')
    zoomin = ZoomInTool(dimensions='width')
    zoomout = ZoomOutTool(dimensions='width')
    crosshair = CrosshairTool(dimensions='height')
    shared_tools = [
        'xpan', boxzoom, boxsel, crosshair, 'undo', 'redo', zoomin, zoomout,
        'reset'
    ]

    figargs = dict(tools=shared_tools, )
    cursor = Span(dimension='height',
                  line_color='red',
                  line_dash='dashed',
                  line_width=1)
    # Waveform plot.
    ch0 = figure(name='ch0', tooltips=[("time", "$x{0.0000}")], **figargs)
    ch0.line(x='seconds', y='ch0', source=source, nonselection_line_alpha=0.6)
    # Link pan, zoom events for plots with x_range.
    ch0.x_range.on_change('start', x_range_cb)
    ch0.x_range.on_change('end', x_range_cb)
    ch0.on_event(SelectionGeometry, selection_cb)
    ch0.on_event(Tap, cursor_cb)
    ch0.add_layout(cursor)
    # Placeholder until a file is loaded; update_sgram replaces it.
    sgrams = [np.ones((1, 1))]
    specsource = ColumnDataSource(data=dict(sgram0=[sgrams[0]]))
    fq_marker_source = ColumnDataSource(
        data=dict(freq=[0.0], power=[0.0], time=[0.0]))
    # Spectrogram plot.
    spec0 = figure(
        name='spec0',
        x_range=ch0.x_range,  # Keep times synchronized
        tooltips=[("time", "$x{0.0000}"), ("freq", "$y{0.0000}"),
                  ("value", "@sgram0{0.000000}")],
        **figargs)
    spec0.add_layout(cursor)
    spec0_fq_marker = spec0.circle(x='time',
                                   y='freq',
                                   source=fq_marker_source,
                                   size=6,
                                   line_color='red',
                                   fill_color='red',
                                   visible=False)
    spec0.x_range.range_padding = spec0.y_range.range_padding = 0
    spec0cmap = LogColorMapper(palette=r_Greys256,
                               low_color=params['low_thresh_color'])
    low_thresh_slider = Slider(start=1.0,
                               end=12.0,
                               step=0.125,
                               value=params['low_thresh_power'],
                               title='Low threshold')
    winsize_slider = Slider(start=5.0,
                            end=40.0,
                            step=5.0,
                            value=params['window_size'],
                            title='Analysis window (ms)')
    filter_sel = Select(
        options=['no filter (clear)', 'no filter', 'passband', 'stopband'],
        value='no filter (clear)')
    spec0img = spec0.image(image='sgram0',
                           x=0,
                           y=0,
                           color_mapper=spec0cmap,
                           level='image',
                           source=specsource)
    spec0.grid.grid_line_width = 0.0
    low_thresh_slider.on_change('value', low_thresh_cb)
    winsize_slider.on_change('value', winsize_cb)
    filter_sel.on_change('value', filter_type_cb)
    selbox = BoxAnnotation(name='selbox',
                           left=None,
                           right=None,
                           fill_color='green',
                           fill_alpha=0.1,
                           line_color='green',
                           line_width=1.5,
                           line_dash='dashed',
                           visible=False)
    sgselbox = BoxAnnotation(name='sgselbox',
                             top=None,
                             bottom=None,
                             fill_color='red',
                             fill_alpha=0.1,
                             line_color='red',
                             line_width=1.5,
                             line_dash='dashed',
                             visible=False)
    ch0.add_layout(selbox)
    spec0.add_layout(selbox)
    spec0.add_layout(sgselbox)
    spec0.on_event(SelectionGeometry, selection_cb)
    spec0.on_event(Tap, cursor_cb)
    grid = gridplot([ch0, spec0],
                    ncols=1,
                    plot_height=200,
                    toolbar_location='left',
                    toolbar_options={'logo': None},
                    merge_tools=True)
    # Spectrum-slice plot (frequency on x, log power on y).
    spslice_chtool = CrosshairTool(dimensions='height')
    spslice0 = figure(name='spslice0',
                      plot_width=400,
                      plot_height=250,
                      y_axis_type='log',
                      y_range=(10**-9, 1),
                      tools=[spboxsel, spslice_chtool],
                      toolbar_location='left')
    spslice0.toolbar.logo = None
    spslice_source = ColumnDataSource(
        data=dict(freq=np.array([]), power=np.array([])))
    spslice0.line(x='freq', y='power', source=spslice_source)
    spselbox = BoxAnnotation(name='spselbox',
                             left=None,
                             right=None,
                             fill_color='red',
                             fill_alpha=0.1,
                             line_color='red',
                             line_width=1.5,
                             line_dash='dashed',
                             visible=False)
    spslice0.add_layout(spselbox)
    spslice0.on_event(SelectionGeometry, filter_sel_cb)
    thresh_box = BoxAnnotation(fill_color=params['low_thresh_color'])
    spslice0.add_layout(thresh_box)
    spslice0.on_event(MouseMove, spslice_move_cb)
    spslice0_fq_marker = spslice0.circle(x='freq',
                                         y='power',
                                         source=fq_marker_source,
                                         size=6,
                                         line_color='red',
                                         fill_color='red',
                                         visible=False)
    num_fmtr = NumberFormatter(format='0.0000')
    det_num_fmtr = NumberFormatter(format='0.000000000')
    fq_marker_table = DataTable(source=fq_marker_source,
                                columns=[
                                    TableColumn(field="freq",
                                                title="Frequency",
                                                formatter=num_fmtr),
                                    TableColumn(field="power",
                                                title="Power",
                                                formatter=det_num_fmtr),
                                    TableColumn(field="time",
                                                title="Time",
                                                formatter=num_fmtr),
                                ],
                                width=300)
    control_col = column(row(playvisbtn, playselbtn), low_thresh_slider,
                         winsize_slider, filter_sel, fq_marker_table)
    grid2 = gridplot([spslice0, control_col], ncols=2)

    mainLayout = column(
        fselect_row,
        grid,  #low_thresh_slider, winsize_slider,
        grid2,
        name='mainLayout')
    doc.add_root(mainLayout)
    return doc
Пример #28
0
    def __init__(self):
        # Placeholder frame used by the preview table and the demand plot
        # until real data is loaded: a fixed date string and zero demand.
        n_rows = DF_NUM_PREVIEW_ROWS
        self.input_df = pd.DataFrame({'x': ['2010-01-01'] * n_rows,
                                      'y': [0] * n_rows})
        self.forecasted_df = None
        self.datefmt = DateFormatter(format='%m-%d-%Y')
        self.inputs = None
        self.x_range = [0, 10]

        # Demand plot: one blue line bound to plot_data_source.
        plot = figure(x_range=self.x_range,
                      x_axis_type="datetime",
                      tools=["pan", 'wheel_zoom'])
        self.demand_plot = plot
        self.plot_data_source = ColumnDataSource(data=self.input_df)
        self.line1 = plot.line(x='x',
                               y='y',
                               source=self.plot_data_source,
                               line_color='blue',
                               name='line1')
        plot.xaxis.formatter = DatetimeTickFormatter(days="%d %b %Y", hours="")
        plot.axis.minor_tick_line_color = None
        plot.xaxis[0].ticker.desired_num_ticks = 10
        plot.xaxis.major_label_orientation = radians(30)

        # Widgets, in the order the user walks through them.
        not_selected = 'Not Selected'
        side_margin = (0, 15, 0, 15)

        def column_select(title):
            # A column chooser that starts with only the placeholder option.
            return Select(title=title, value=not_selected, options=[not_selected])

        def today():
            return datetime.datetime.date(pd.to_datetime("today"))

        self.data_source_selector = Select(
            title='Step 1/5: Select Data',
            value=not_selected,
            options=[not_selected, 'Use Example Data', 'Upload Data'])
        self.file_input = FileInput(accept='.csv,.xlsx')
        self.data_table = DataTable(height=DATATABLE_PREVIEW_HEIGHT,
                                    width=DATATABLE_PREVIEW_WIDTH,
                                    fit_columns=False,
                                    index_position=None,
                                    margin=side_margin)
        self.data_preview_paragraph = Paragraph(text='Data Preview:',
                                                margin=side_margin)
        self.values_col_selector = column_select(
            'Step 2/5: Select column with demand values')
        self.product_id_col_selector = column_select(
            'Step 3/5: Select column with product ID')
        self.date_col_selector = column_select("Step 4/5: Select date column")
        self.last_date_picker = DatePicker(
            title='Select the date of last observation',
            max_date=today(),
            value=today())
        self.workdays_checkboxgroup = CheckboxGroup(
            labels=["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
            active=[],
            inline=True,
            margin=(0, 15, 0, 0))
        self.workdays_apply_button = Button(label='Select Business Days',
                                            button_type='primary')
        self.product_selector_plotting = Select(
            title='Select Product to Display',
            value='v1',
            options=['v1', 'v2'])
        self.prediction_button = Button(
            label='Forecast Demand for Selected Product ID',
            button_type='primary')
        self.default_info_msg = 'This window will contain additional information,\nas you interact with the app.'
        self.info_paragraph = PreText(
            text=f'Details:\n{self.default_info_msg}')

        # Registry of widgets by key; key -> attribute name on self.
        registry = (
            ('data_source_selector', 'data_source_selector'),
            ('file_input', 'file_input'),
            ('values_col_selector', 'values_col_selector'),
            ('product_id_col_selector', 'product_id_col_selector'),
            ('data_preview_paragraph', 'data_preview_paragraph'),
            ('data_table', 'data_table'),
            ('product_selector', 'product_selector_plotting'),
            ('demand_plot', 'demand_plot'),
            ('date_col_selector', 'date_col_selector'),
            ('last_date_picker', 'last_date_picker'),
            ('workdays_checkboxgroup', 'workdays_checkboxgroup'),
            ('workdays_apply_button', 'workdays_apply_button'),
            ('prediction_button', 'prediction_button'),
        )
        self.widgets = {key: getattr(self, attr) for key, attr in registry}

        # Column choices made by the user; unset until selected.
        self.values_colname = None
        self.product_id_colname = None
        self.date_colname = None
        self.product_ids = []
Пример #29
0
from qmp.potential import Potential
from qmp.potential import preset_potentials
from qmp.integrator.dyn_tools import create_gaussian
import numpy as np

from bokeh.layouts import column, grid, row
from bokeh.models import Button, FileInput, Select, ColumnDataSource
from bokeh.models import TextInput
from bokeh.models.widgets.buttons import Toggle, Dropdown
from bokeh.models.widgets.sliders import Slider, RangeSlider
from bokeh.models.widgets.groups import RadioButtonGroup
from bokeh.plotting import figure, curdoc
import pickle
import base64

# Module-level upload widget; load_data() passes its ``value`` to decode_data().
file = FileInput()


def decode_data(data):
    """Decode a base64-encoded pickle payload such as ``FileInput.value``.

    :param data: base64 text or bytes, as accepted by ``base64.b64decode``.
    :return: the unpickled object.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload. When ``data`` comes from a browser upload it is untrusted
    input; only use this with files from trusted sources, or switch to a
    data-only format (e.g. JSON or ``numpy.load(..., allow_pickle=False)``).
    """
    raw = base64.b64decode(data)
    # SECURITY(review): unpickling uploaded bytes permits remote code execution.
    return pickle.loads(raw)


def load_data():
    """Unpickle the uploaded file and build a plot from it.

    Reads ``file.value`` (the module-level ``FileInput``) and plots
    ``data['x']`` against the real part of ``data['psi_t'][0, 0]`` via
    ``create_plot``.

    :return: the plot returned by ``create_plot``, or ``None`` when nothing
        has been uploaded yet.
    """
    if not file.value:
        # Nothing uploaded: b64decode('') gives b'' and pickle.loads would
        # raise EOFError.
        return None
    data = decode_data(file.value)

    x = data['x']
    # NOTE(review): assumes data['psi_t'] is a NumPy array indexable as
    # [0, 0]; confirm against whatever produces these pickles.
    y = data['psi_t'][0, 0].real

    # The plot used to be built and then discarded; hand it back to the caller.
    return create_plot(x, y)
Пример #30
0
def create():
    det_data = []
    fit_params = {}
    js_data = ColumnDataSource(data=dict(content=["", ""], fname=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith((".ccl", ".dat")):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        scan_list = [s["idx"] for s in det_data]
        file_list = []
        for scan in det_data:
            file_list.append(os.path.basename(scan["original_filename"]))

        scan_table_source.data.update(
            file=file_list,
            scan=scan_list,
            param=[None] * len(scan_list),
            fit=[0] * len(scan_list),
            export=[True] * len(scan_list),
        )
        scan_table_source.selected.indices = []
        scan_table_source.selected.indices = [0]

        scan_motor_select.options = det_data[0]["scan_motors"]
        scan_motor_select.value = det_data[0]["scan_motor"]
        param_select.value = "user defined"

    file_select = MultiSelect(title="Available .ccl/.dat files:",
                              width=210,
                              height=250)

    def file_open_button_callback():
        nonlocal det_data
        det_data = []
        for f_name in file_select.value:
            with open(f_name) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        for f_name in file_select.value:
            with open(f_name) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            det_data.extend(append_data)

        _init_datatable()

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal det_data
        det_data = []
        for f_str, f_name in zip(new, upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        for f_str, f_name in zip(new, append_upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                _, ext = os.path.splitext(f_name)
                append_data = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(append_data, monitor_spinner.value)
            det_data.extend(append_data)

        _init_datatable()

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    append_upload_button = FileInput(accept=".ccl,.dat",
                                     multiple=True,
                                     width=200,
                                     disabled=True)
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, _old, new):
        if det_data:
            pyzebra.normalize_dataset(det_data, new)
            _update_plot()

    monitor_spinner = Spinner(title="Monitor:",
                              mode="int",
                              value=100_000,
                              low=1,
                              width=145)
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def scan_motor_select_callback(_attr, _old, new):
        if det_data:
            for scan in det_data:
                scan["scan_motor"] = new
            _update_plot()

    scan_motor_select = Select(title="Scan motor:", options=[], width=145)
    scan_motor_select.on_change("value", scan_motor_select_callback)

    def _update_table():
        fit_ok = [(1 if "fit" in scan else 0) for scan in det_data]
        scan_table_source.data.update(fit=fit_ok)

    def _update_plot():
        _update_single_scan_plot(_get_selected_scan())
        _update_overview()

    def _update_single_scan_plot(scan):
        scan_motor = scan["scan_motor"]

        y = scan["counts"]
        x = scan[scan_motor]

        plot.axis[0].axis_label = scan_motor
        plot_scatter_source.data.update(x=x,
                                        y=y,
                                        y_upper=y + np.sqrt(y),
                                        y_lower=y - np.sqrt(y))

        fit = scan.get("fit")
        if fit is not None:
            x_fit = np.linspace(x[0], x[-1], 100)
            plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

            x_bkg = []
            y_bkg = []
            xs_peak = []
            ys_peak = []
            comps = fit.eval_components(x=x_fit)
            for i, model in enumerate(fit_params):
                if "linear" in model:
                    x_bkg = x_fit
                    y_bkg = comps[f"f{i}_"]

                elif any(val in model
                         for val in ("gaussian", "voigt", "pvoigt")):
                    xs_peak.append(x_fit)
                    ys_peak.append(comps[f"f{i}_"])

            plot_bkg_source.data.update(x=x_bkg, y=y_bkg)
            plot_peak_source.data.update(xs=xs_peak, ys=ys_peak)

            fit_output_textinput.value = fit.fit_report()

        else:
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""

    def _update_overview():
        """Refresh the "overview" and "overview map" plots from the scan table.

        Only scans whose "param" cell in the scan table is not None are shown.
        The overview plot gets one line per scan (counts vs. scan motor); the
        overview map gets a scatter of (scan motor, param) points colored by
        counts, plus an image interpolated from those points.
        """
        # Per-scan data for the multiline overview plot.
        xs = []  # scan-motor positions, one entry per shown scan
        ys = []  # counts, one entry per shown scan
        param = []  # parameter value of each shown scan (hover tooltip)
        # Flattened per-point data for the overview map.
        x = []  # scan-motor positions of all shown points
        y = []  # parameter value, repeated for every point of its scan
        par = []  # counts of all shown points (drives the color mapper)
        for s, p in enumerate(scan_table_source.data["param"]):
            if p is None:
                continue

            scan = det_data[s]
            scan_motor = scan["scan_motor"]
            motor_values = scan[scan_motor]
            counts = scan["counts"]

            xs.append(motor_values)
            x.extend(motor_values)
            ys.append(counts)
            y.extend([float(p)] * len(motor_values))
            param.append(float(p))
            par.extend(counts)

        if det_data:
            # Use the first scan's motor name as axis label of both overview
            # plots (axis[0]; presumably the "Scan motor" axis — confirm).
            scan_motor = det_data[0]["scan_motor"]
            ov_plot.axis[0].axis_label = scan_motor
            ov_param_plot.axis[0].axis_label = scan_motor

        ov_plot_mline_source.data.update(xs=xs,
                                         ys=ys,
                                         param=param,
                                         color=color_palette(len(xs)))

        if par:
            # Scale the color mapper to the range of all displayed counts.
            # `par` is `ys` flattened, so this equals the min/max over every
            # scan, without shadowing `y` and without calling np.min on an
            # individual (possibly empty) counts array.
            mapper["transform"].low = np.min(par)
            mapper["transform"].high = np.max(par)
        ov_param_plot_scatter_source.data.update(x=x, y=y, param=par)

        if y:
            # NOTE(review): scipy.interpolate.interp2d is deprecated in recent
            # SciPy releases; verify against the SciPy version in use.
            interp_f = interpolate.interp2d(x, y, par)
            x1, x2 = min(x), max(x)
            y1, y2 = min(y), max(y)
            # One image pixel per ~10 screen pixels of the plot area.
            # NOTE(review): inner_width/inner_height are reported by the
            # browser; confirm they are set before this runs.
            image = interp_f(
                np.linspace(x1, x2, ov_param_plot.inner_width // 10),
                np.linspace(y1, y2, ov_param_plot.inner_height // 10),
                assume_sorted=True,
            )
            ov_param_plot_image_source.data.update(image=[image],
                                                   x=[x1],
                                                   y=[y1],
                                                   dw=[x2 - x1],
                                                   dh=[y2 - y1])
        else:
            ov_param_plot_image_source.data.update(image=[],
                                                   x=[],
                                                   y=[],
                                                   dw=[],
                                                   dh=[])

    def _update_param_plot():
        """Scatter the chosen fit parameter of every fitted scan vs. its param.

        Scans without a "fit" entry are skipped; nothing is plotted while no
        fit parameter is selected.
        """
        fit_param = fit_param_select.value
        points = [
            (p, scan["fit"].values[fit_param])
            for scan, p in zip(det_data, scan_table_source.data["param"])
            if "fit" in scan and fit_param
        ]
        x = [param_value for param_value, _ in points]
        y = [fit_value for _, fit_value in points]

        param_plot_scatter_source.data.update(x=x, y=y)

    # Main plot ("single scan" tab): data points of the selected scan with
    # whiskers, the best-fit curve, the linear background and the peak
    # components, plus two dashed vertical spans marking the fit range.
    plot = Plot(
        x_range=DataRange1d(),
        # exclude hidden renderers (e.g. toggled off via the legend) from
        # automatic y-range computation
        y_range=DataRange1d(only_visible=True),
        plot_height=450,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # Sources start with a single placeholder point; _update_plot replaces
    # their data.
    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue"))
    # Whisker from y_lower to y_upper at each x (presumably error bars).
    plot.add_layout(
        Whisker(source=plot_scatter_source,
                base="x",
                upper="y_upper",
                lower="y_lower"))

    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"))

    # One line per peak component (multi-line source).
    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"))

    # Fit range markers; their location is set by the "Fit from:" / "to:"
    # spinner callbacks.
    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)

    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_to_span)

    # Clicking a legend entry hides/shows the corresponding glyph.
    plot.add_layout(
        Legend(
            items=[
                ("data", [plot_scatter]),
                ("best fit", [plot_fit]),
                ("peak", [plot_peak]),
                ("linear", [plot_bkg]),
            ],
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Overview multilines plot ("overview" tab): one line of counts vs. scan
    # motor per scan, each with its own palette color; hovering shows the
    # scan's parameter value. Data is filled in by _update_overview.
    ov_plot = Plot(x_range=DataRange1d(),
                   y_range=DataRange1d(),
                   plot_height=450,
                   plot_width=700)

    ov_plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    ov_plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    ov_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    ov_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    ov_plot_mline_source = ColumnDataSource(
        dict(xs=[], ys=[], param=[], color=[]))
    ov_plot.add_glyph(ov_plot_mline_source,
                      MultiLine(xs="xs", ys="ys", line_color="color"))

    hover_tool = HoverTool(tooltips=[("param", "@param")])
    # Register each tool once. A second add_tools() call with fresh
    # PanTool/WheelZoomTool/ResetTool instances would add duplicate tools and
    # toolbar buttons.
    ov_plot.add_tools(PanTool(), WheelZoomTool(), hover_tool, ResetTool())
    ov_plot.toolbar.logo = None

    # Overview params plot ("overview map" tab): scan motor vs. parameter
    # value, showing an interpolated counts image and a scatter of the measured
    # points colored by counts. Data is filled in by _update_overview.
    ov_param_plot = Plot(x_range=DataRange1d(),
                         y_range=DataRange1d(),
                         plot_height=450,
                         plot_width=700)

    ov_param_plot.add_layout(LinearAxis(axis_label="Param"), place="left")
    ov_param_plot.add_layout(LinearAxis(axis_label="Scan motor"),
                             place="below")

    ov_param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    ov_param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # Image glyph added before the scatter (presumably so the points render
    # on top of it).
    ov_param_plot_image_source = ColumnDataSource(
        dict(image=[], x=[], y=[], dw=[], dh=[]))
    ov_param_plot.add_glyph(
        ov_param_plot_image_source,
        Image(image="image", x="x", y="y", dw="dw", dh="dh"))

    # The scatter's "param" column holds counts (_update_overview passes
    # param=par); low/high are placeholders overwritten there.
    ov_param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[], param=[]))
    mapper = linear_cmap(field_name="param", palette=Turbo256, low=0, high=50)
    ov_param_plot.add_glyph(
        ov_param_plot_scatter_source,
        Scatter(x="x", y="y", line_color=mapper, fill_color=mapper, size=10),
    )

    ov_param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    ov_param_plot.toolbar.logo = None

    # Parameter plot ("parameter plot" tab): the fit parameter chosen in
    # fit_param_select vs. each scan's param value (see _update_param_plot).
    param_plot = Plot(x_range=DataRange1d(),
                      y_range=DataRange1d(),
                      plot_height=400,
                      plot_width=700)

    param_plot.add_layout(LinearAxis(axis_label="Fit parameter"), place="left")
    param_plot.add_layout(LinearAxis(axis_label="Parameter"), place="below")

    param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[]))
    param_plot.add_glyph(param_plot_scatter_source, Scatter(x="x", y="y"))

    param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    param_plot.toolbar.logo = None

    def fit_param_select_callback(_attr, _old, _new):
        """Redraw the parameter plot when another fit parameter is chosen."""
        _update_param_plot()

    # Options are filled in after processing, from the fitted scans' params.
    fit_param_select = Select(title="Fit parameter", options=[], width=145)
    fit_param_select.on_change("value", fit_param_select_callback)

    # Plot tabs
    plots = Tabs(tabs=[
        Panel(child=plot, title="single scan"),
        Panel(child=ov_plot, title="overview"),
        Panel(child=ov_param_plot, title="overview map"),
        Panel(child=column(param_plot, row(fit_param_select)),
              title="parameter plot"),
    ])

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        """Keep the scan table single-select and replot the selected scan."""
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            scan_table_source.selected.indices = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        _update_plot()

    def scan_table_source_preview_callback(_attr, _old, _new):
        """Regenerate the export preview whenever the table data changes."""
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(file=[], scan=[], param=[], fit=[], export=[]))
    # Two separate "data" callbacks are registered on this source: this one
    # (export preview) and scan_table_source_callback below (replot). They
    # have distinct names so neither shadows the other.
    scan_table_source.on_change("data", scan_table_source_preview_callback)

    scan_table = DataTable(
        source=scan_table_source,
        columns=[
            TableColumn(field="file", title="file", width=150),
            TableColumn(field="scan", title="scan", width=50),
            TableColumn(field="param",
                        title="param",
                        editor=NumberEditor(),
                        width=50),
            TableColumn(field="fit", title="Fit", width=50),
            TableColumn(field="export",
                        title="Export",
                        editor=CheckboxEditor(),
                        width=50),
        ],
        width=410,  # +60 because of the index column
        editable=True,
        autosize_mode="none",
    )

    def scan_table_source_callback(_attr, _old, _new):
        """Replot the selected scan (if any) whenever the table data changes."""
        if scan_table_source.selected.indices:
            _update_plot()

    scan_table_source.selected.on_change("indices", scan_table_select_callback)
    scan_table_source.on_change("data", scan_table_source_callback)

    def _get_selected_scan():
        """Return the det_data entry for the first selected scan table row."""
        selected_row = scan_table_source.selected.indices[0]
        return det_data[selected_row]

    def param_select_callback(_attr, _old, new):
        """Fill the scan table "param" column from the chosen scan field.

        "user defined" resets the column to None for every scan (values are
        then entered in the editable table); any other choice copies that
        key from each scan in det_data.
        """
        param = ([None] * len(det_data) if new == "user defined" else
                 [scan[new] for scan in det_data])

        scan_table_source.data["param"] = param
        _update_param_plot()

    param_select = Select(
        title="Parameter:",
        options=["user defined", "temp", "mf", "h", "k", "l"],
        value="user defined",
        width=145,
    )
    param_select.on_change("value", param_select_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        """Move the dashed "fit from" marker in the main plot to `new`."""
        fit_from_span.location = new

    # Lower bound of the fit range, passed as fit_from to pyzebra.fit_scan.
    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    def fit_to_spinner_callback(_attr, _old, new):
        """Move the dashed "fit to" marker in the main plot to `new`."""
        fit_to_span.location = new

    # Upper bound of the fit range, passed as fit_to to pyzebra.fit_scan.
    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        """Add a fit function of type ``click.item`` under a unique tag.

        Tags have the form "<function>-<n>", where n is a running counter
        kept in ``fitparams_select.tags[0]``.
        """
        function_name = click.item
        counter = fitparams_select.tags[0]
        new_tag = f"{function_name}-{counter}"
        # bokeh requires (str, str) for MultiSelect options
        fitparams_select.options.append((new_tag, function_name))
        fit_params[new_tag] = fitparams_factory(function_name)
        fitparams_select.tags[0] = counter + 1

    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=[
            ("Linear", "linear"),
            ("Gaussian", "gaussian"),
            ("Voigt", "voigt"),
            ("Pseudo Voigt", "pvoigt"),
            # ("Pseudo Voigt1", "pseudovoigt1"),
        ],
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        """Show the parameters of the selected fit function in the table.

        Only a single selection is allowed: a multi-selection is reverted to
        the previous value, and the follow-up event caused by that revert is
        ignored. An empty selection clears the table.
        """
        # Avoid selection of multiple indices (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            fitparams_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        if new:
            fitparams_table_source.data.update(fit_params[new[0]])
        else:
            fitparams_table_source.data.update(
                dict(param=[], value=[], vary=[], min=[], max=[]))

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    # tags[0] is used as a running counter for unique fit-function tags.
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        """Delete the selected fit function and its entry in the select list.

        Clearing the selection afterwards also empties the parameter table
        through fitparams_select_callback.
        """
        if not fitparams_select.value:
            return

        sel_tag = fitparams_select.value[0]
        del fit_params[sel_tag]

        matching_option = next(
            (opt for opt in fitparams_select.options if opt[0] == sel_tag),
            None,
        )
        if matching_option is not None:
            fitparams_select.options.remove(matching_option)

        fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        """Return a fresh default parameter table for a fit function.

        The dict has the columns of the fit parameters table (``param``,
        ``value``, ``vary``, ``min``, ``max``) with one entry per model
        parameter. New lists are created on every call.

        :param function: fit function name, e.g. "linear" or "gaussian"
        :return: dict of equally long column lists
        :raises ValueError: if ``function`` is not a known fit function
        """
        names_by_function = {
            "linear": ("slope", "intercept"),
            "gaussian": ("amplitude", "center", "sigma"),
            "voigt": ("amplitude", "center", "sigma", "gamma"),
            "pvoigt": ("amplitude", "center", "sigma", "fraction"),
            "pseudovoigt1":
                ("amplitude", "center", "g_sigma", "l_sigma", "fraction"),
        }
        if function not in names_by_function:
            raise ValueError("Unknown fit function")

        names = list(names_by_function[function])
        count = len(names)
        table = {
            "param": names,
            "value": [None] * count,
            "vary": [True] * count,
            "min": [None] * count,
            "max": [None] * count,
        }

        # Function-specific starting values/bounds replacing the free,
        # unbounded defaults above (built per call, so never shared).
        overrides = {
            "linear": {
                "value": [0, 1],
                "vary": [False, True],
                "min": [None, 0],
            },
            "gaussian": {
                "min": [0, None, None],
            },
        }
        table.update(overrides.get(function, {}))
        return table

    # Editable table showing the parameters of the fit function currently
    # selected in fitparams_select (filled by fitparams_select_callback).
    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=[
            TableColumn(field="param", title="Parameter"),
            TableColumn(field="value", title="Value", editor=NumberEditor()),
            TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
            TableColumn(field="min", title="Min", editor=NumberEditor()),
            TableColumn(field="max", title="Max", editor=NumberEditor()),
        ],
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `linear` and `gaussian` fit functions added; with the tag
    # counter starting at 0 they get the tags "linear-0" and "gaussian-1"
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="linear"))
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="gaussian"))
    fitparams_select.value = ["gaussian-1"]  # add selection to gauss

    # Shows fit.fit_report() of the selected scan (set in _update_plot).
    fit_output_textinput = TextAreaInput(title="Fit results:",
                                         width=750,
                                         height=200)

    def _fit_and_integrate_scan(scan):
        """Fit ``scan`` and compute its intensity with the current settings.

        Uses the active fit functions (``fit_params``), the "Fit from"/"to"
        spinner range, the selected intensity method and the Lorentz
        correction checkbox. Results are presumably stored on the scan dict
        (later code reads ``scan["fit"]``). Shared by both "Process" buttons
        so they cannot drift apart.
        """
        pyzebra.fit_scan(scan,
                         fit_params,
                         fit_from=fit_from_spinner.value,
                         fit_to=fit_to_spinner.value)
        pyzebra.get_area(
            scan,
            area_method=AREA_METHODS[area_method_radiobutton.active],
            lorentz=lorentz_checkbox.active,
        )

    def _refresh_after_processing():
        """Redraw plot and table, then repopulate the fit parameter selector."""
        _update_plot()
        _update_table()

        # Offer the fit parameter names of the first fitted scan and select
        # the first one (fit_param_select_callback fires on a value change).
        for fitted_scan in det_data:
            if "fit" in fitted_scan:
                options = list(fitted_scan["fit"].params.keys())
                fit_param_select.options = options
                fit_param_select.value = options[0]
                break
        _update_param_plot()

    def proc_all_button_callback():
        """Fit every scan whose "Export" checkbox is ticked, then refresh."""
        for scan, export in zip(det_data, scan_table_source.data["export"]):
            if export:
                _fit_and_integrate_scan(scan)

        _refresh_after_processing()

    proc_all_button = Button(label="Process All",
                             button_type="primary",
                             width=145)
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        """Fit only the scan selected in the scan table, then refresh."""
        _fit_and_integrate_scan(_get_selected_scan())
        _refresh_after_processing()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    # Intensity method selector; its active index is looked up in AREA_METHODS
    # when processing scans.
    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(labels=["Function", "Area"],
                                         active=0,
                                         width=145)

    # Its `active` list is passed as `lorentz` to pyzebra.get_area.
    lorentz_checkbox = CheckboxGroup(labels=["Lorentz Correction"],
                                     width=145,
                                     margin=(13, 5, 5, 5))

    # Read-only view of the files that "Download File(s)" would save; filled
    # by _update_preview.
    export_preview_textinput = TextAreaInput(title="Export file preview:",
                                             width=450,
                                             height=400)

    def _update_preview():
        """Refresh the export preview text and the download payload.

        Reads the ".comm" and ".incomm" files (if present) from a fresh
        temporary directory and stores their contents in js_data (for the
        download button) and in the preview text area.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp"
            export_data = [
                s for s, export in zip(det_data,
                                       scan_table_source.data["export"])
                if export
            ]

            # NOTE(review): the export call is commented out, so nothing is
            # written into the fresh temporary directory and both entries
            # below stay empty.
            # pyzebra.export_1D(export_data, temp_file, "fullprof")

            preview_sections = []
            file_content = []
            for ext in (".comm", ".incomm"):
                fname = temp_file + ext
                content = ""
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                    preview_sections.append(f"{ext} file:\n" + content)
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = "".join(preview_sections)

    # Download runs client-side: the CustomJS code (javaScript, defined
    # elsewhere in this file) receives js_data, whose "content" column is
    # filled by _update_preview.
    save_button = Button(label="Download File(s)",
                         button_type="success",
                         width=220)
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))

    # Fit controls: fit function list, parameter table, fit range, intensity
    # options and the process buttons.
    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(scan_table,
                         row(monitor_spinner, scan_motor_select, param_select))

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(export_preview_textinput, row(save_button))

    # Top row: import / scan table / plot tabs / export preview;
    # bottom row: fit controls and fit report.
    tab_layout = column(
        row(import_layout, scan_layout, plots, Spacer(width=30),
            export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="param study")