Exemplo n.º 1
0
def modify_doc(doc):
    """Add a Spinner stacked above a small circle plot to *doc*.

    Every change of the spinner's ``value`` writes ``[old, new]`` into the
    ``val`` column of the plot's data source. The plot also carries a
    CustomAction tool whose JS callback is produced by ``RECORD``.

    :param doc: bokeh document that receives the layout
    :return: the same document
    """
    source = ColumnDataSource(dict(x=[1, 2], y=[1, 1], val=["a", "b"]))

    # Ranges are created in the same order as in the inline form (x, then y).
    x_range = Range1d(0, 1)
    y_range = Range1d(0, 1)
    plot = Plot(plot_height=400,
                plot_width=400,
                x_range=x_range,
                y_range=y_range,
                min_border=0)
    plot.add_glyph(source, Circle(x='x', y='y'))

    record_callback = CustomJS(args=dict(s=source),
                               code=RECORD("data", "s.data"))
    plot.add_tools(CustomAction(callback=record_callback))

    spinner_options = dict(low=-1,
                           high=10,
                           step=0.1,
                           value=4,
                           css_classes=["foo"],
                           format="0[.]0")
    value_spinner = Spinner(**spinner_options)

    def on_value_change(attr, old, new):
        # Mirror the transition into the data source.
        source.data['val'] = [old, new]

    value_spinner.on_change('value', on_value_change)
    doc.add_root(column(value_spinner, plot))
    return doc
Exemplo n.º 2
0
 def create_plot_width_input(self) -> Spinner:
     """Build the "Width" spinner for the figure's plot width.

     The spinner starts at ``self.figure.plot_width``, moves in steps of 5
     and forwards ``value`` changes to ``self.handle_plot_width_change``.

     Returns:
         The configured ``Spinner`` widget.
     """
     plot_width = self.figure.plot_width
     plot_width_input = Spinner(title="Width",
                                value=plot_width,
                                step=5,
                                width=100)
     plot_width_input.on_change("value", self.handle_plot_width_change)
     return plot_width_input
Exemplo n.º 3
0
 def create_radius_input(self) -> Spinner:
     """Build the "Radius" spinner for the plotted glyph.

     The spinner starts at ``self.plot.glyph.radius``, moves in steps of
     0.01 and forwards ``value`` changes to ``self.handle_radius_change``.

     Returns:
         The configured ``Spinner`` widget.
     """
     radius = self.plot.glyph.radius
     radius_input = Spinner(value=radius,
                            title="Radius",
                            step=0.01,
                            width=100)
     radius_input.on_change("value", self.handle_radius_change)
     return radius_input
Exemplo n.º 4
0
 def create_plot_height_input(self) -> Spinner:
     """Build the "Height" spinner for the figure's plot height.

     The spinner starts at ``self.figure.plot_height``, moves in steps of 5
     and forwards ``value`` changes to ``self.handle_plot_height_change``.

     Returns:
         The configured ``Spinner`` widget.
     """
     plot_height = self.figure.plot_height
     plot_height_input = Spinner(title="Height",
                                 value=plot_height,
                                 step=5,
                                 width=100)
     plot_height_input.on_change("value", self.handle_plot_height_change)
     return plot_height_input
Exemplo n.º 5
0
 def create_range_end_input(self, dimension: str) -> Spinner:
     """Build the "End" spinner for the range of the given axis dimension.

     The range is looked up with ``self.get_axes_range(dimension)``. The
     widget is created disabled when that range is a ``FactorRange``.
     ``value`` changes are forwarded to ``self.handle_range_end_change``
     with ``dimension`` bound as its first argument.

     Returns:
         The configured ``Spinner`` widget.
     """
     rng = self.get_axes_range(dimension)
     is_factor_range = isinstance(rng, FactorRange)
     end_input = Spinner(
         title="End",
         value=rng.end,
         step=0.05,
         width=100,
         disabled=is_factor_range,
     )
     on_end_change = partial(self.handle_range_end_change, dimension)
     end_input.on_change("value", on_end_change)
     return end_input
Exemplo n.º 6
0
def modify_doc(doc):
    """Fill *doc* with a column holding a Spinner and a circle plot.

    The spinner's ``value`` callback stores ``[old, new]`` in the ``val``
    column of the shared data source; the plot gets a CustomAction tool
    whose JS code comes from ``RECORD``.

    :param doc: bokeh document that receives the layout
    :return: the same document
    """
    source = ColumnDataSource(dict(x=[1, 2], y=[1, 1], val=["a", "b"]))

    # Ranges are created in the same order as in the inline form (x, then y).
    x_range = Range1d(0, 1)
    y_range = Range1d(0, 1)
    plot = Plot(
        plot_height=400,
        plot_width=400,
        x_range=x_range,
        y_range=y_range,
        min_border=0,
    )
    plot.add_glyph(source, Circle(x='x', y='y'))

    record_callback = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
    plot.add_tools(CustomAction(callback=record_callback))

    value_spinner = Spinner(
        low=-1,
        high=10,
        step=0.1,
        value=4,
        css_classes=["foo"],
    )

    def on_value_change(attr, old, new):
        # Mirror the transition into the data source.
        source.data['val'] = [old, new]

    value_spinner.on_change('value', on_value_change)
    doc.add_root(column(value_spinner, plot))
    return doc
def wcushow(doc):
    '''
    bokeh function server: builds the live WCU screen

    Reads the channel configuration from the project csv files, opens the
    serial port selected in ``settings``, logs incoming data to a csv file,
    builds the "Gauges" and "Graphs" tabs and registers a periodic callback
    that refreshes them.

    :param doc: bokeh document
    :return: updated document
    '''
    doc.theme = settings.colortheme
    TOOLTIPS = [
        ("(x,y)", "($x, $y)"),
    ]

    renderer = 'webgl'
    graphTools = 'pan,wheel_zoom,box_zoom,zoom_in,zoom_out,hover,crosshair,undo,redo,reset,save'

    WCUconfig = pd.read_csv('./projectfolder/configuration/dataWCU.csv',
                            sep=';',
                            index_col=False)
    CANconfig = pd.read_csv('./projectfolder/configuration/configCAN.csv',
                            sep=';',
                            index_col=False)

    # comma separated channel names; used as csv header and as select options
    cabecalho = ','.join(WCUconfig['channel'])

    # test temp data folders for wcu
    if os.path.isdir('./_wcu_cacheFiles_'):
        print('./_wcu_cacheFiles_ ok')
    else:
        os.mkdir('./_wcu_cacheFiles_')

    #set COM port baudrate
    baudrate = settings.boudrateselected
    wcuUpdateTimer = 1 / 5  #second -> 1/FPS

    portWCU = settings.port

    #connect to WDU and clean garbage
    comport = connectSerial(portWCU, baudrate)
    cleanCOMPORT(comport=comport)

    #create csv file for writing WCU data
    global wcufilename
    wcufilename = createCSV(cabecalho)

    #to start serial, requires a delay so the arduino can load data into the buffer
    time.sleep(2)

    #get WDU data
    global data, lastupdate
    # initial "last update" string: one leading '0.0' plus one ',0.0' per CAN channel
    lastupdate = '0.0' + ',0.0' * len(CANconfig['Channel'])

    wcufile = open(wcufilename, "at")
    data, lastupdate = updateWCUcsv(seconds=wcuUpdateTimer,
                                    wcufile=wcufile,
                                    comport=comport,
                                    header=cabecalho,
                                    canconfig=CANconfig,
                                    laststr=lastupdate)
    source = ColumnDataSource(data=data)

    # Start Gauges

    #start variables for an array of gauges
    plot = []
    linemin = []
    linemax = []
    valueglyph = []

    #each channel whose 'display' column is 'gauge' gets a gauge plot,
    # each channel whose 'display' column is 'text' gets a text entry
    text_data_s = []
    text_unit_s = []
    text_name_s = []
    text_color_s = []
    texts = []
    texplot = figure()
    for gg, channel in enumerate(WCUconfig['channel']):
        if (len(data[channel]) > 3):  #remove errors
            line = WCUconfig.iloc[gg]
            if (line['display'] == 'gauge'):
                dataValue = pd.to_numeric(data[channel])[len(data[channel]) -
                                                         1]
                plt, mx, mi, vg = plotGauge(dataValue,
                                            unit=line['unit'],
                                            name=channel,
                                            color=line['color'],
                                            offset=line['minvalue'],
                                            maxValue=line['maxvalue'],
                                            major_step=line['majorstep'],
                                            minor_step=line['minorstep'])
                plot.append(plt)
                linemax.append(mx)
                linemin.append(mi)
                valueglyph.append(vg)
            if (line['display'] == 'text'):
                dataValue = pd.to_numeric(data[channel])[len(data[channel]) -
                                                         1]
                text_data_s.append(dataValue)
                text_unit_s.append(line['unit'])
                text_name_s.append(channel)
                text_color_s.append(line['color'])

        # NOTE(review): rebuilt on every iteration; only the result of the
        # last iteration is used below - confirm whether this was meant to
        # run once after the loop.
        texplot, texts = plot_text_data(text_data_s,
                                        unit=text_unit_s,
                                        name=text_name_s,
                                        color=text_color_s)

    # Other main plots
    #GPS:
    track, points, livepoint = gps(data, singleplot=True, H=300, W=300)
    gpssource = points.data_source
    livesource = livepoint.data_source

    #Steering for steering angle
    steering, steering_image, steering_text = plot_angle_image()

    # line plots: secondary tabs
    p = figure(
        plot_height=300,
        plot_width=1000,
        y_range=(0, 13000),
        title='RPM',
        x_axis_label='s',
        y_axis_label='rpm',
        toolbar_location="below",
        tooltips=TOOLTIPS,
        output_backend=renderer,
        tools=graphTools,
    )
    p.line(x='time', y='RPM', color='red', source=source)
    p.toolbar.logo = None

    # close and re-open the log file before the live updates start
    try:
        wcufile.close()
        wcufile = open(wcufilename, "at")
    except OSError:
        # best effort: on failure wcufile keeps its previous binding
        pass

    #functions for updating all live gauges and graphics
    def update_data():
        #t1_start = process_time()
        global data, lastupdate, lastline
        # NOTE(review): data.ndim - 1 selects by number of dimensions, not by
        # number of rows - confirm which row is intended here.
        lastline = data.iloc[[data.ndim - 1]].to_csv(header=False,
                                                     index=False).strip('\r\n')
        data, lastupdate = updateWCUcsv(seconds=wcuUpdateTimer,
                                        wcufile=wcufile,
                                        comport=comport,
                                        header=cabecalho,
                                        canconfig=CANconfig,
                                        laststr=lastupdate)

        #t1_stop = process_time()
        #print("HEY: {:.9f}".format((t1_stop - t1_start)))

    def update_source():
        #t1_start = process_time()
        global data
        data = wcu_equations(data)
        source.data = data
        #t1_stop = process_time()
        #print("HEYHEYHEYHEYHEYHEY: {:.9f}".format((t1_stop - t1_start)))

    def callback():
        '''
        callback function to update bokeh server
        :return: none
        '''

        #df = source.to_df()
        #lastline = data.iloc[[df.ndim - 1]].to_csv(header=False, index=False).strip('\r\n')
        #lastupdate = ',' + ','.join(lastline.split(',')[(-len(CANconfig['Channel'])):])

        # schedule the source refresh and the next serial read
        us = partial(update_source)
        doc.add_next_tick_callback(us)

        ud = partial(update_data)
        doc.add_next_tick_callback(ud)

        global data, lastupdate, lastline

        #alternative method
        #data = source.to_df()
        #lastline = df.iloc[[df.ndim - 1]].to_csv(header=False, index=False).strip('\r\n')
        #lastupdate = ',' + ','.join(lastline.split(',')[(-len(CANconfig['Channel'])):])

        linectr = 0  # index into the gauge glyph lists
        text_values_update = []
        text_unit_s_update = []
        text_name_s_update = []
        for gg, channel in enumerate(WCUconfig['channel']):
            line = WCUconfig.iloc[gg]
            if (line['display'] == 'gauge'):
                dataValue = pd.to_numeric(data[channel])[len(data[channel]) -
                                                         1]
                angle = speed_to_angle(dataValue,
                                       offset=line['minvalue'],
                                       max_value=line['maxvalue'])
                linemax[linectr].update(angle=angle)
                linemin[linectr].update(angle=angle - pi)
                valueglyph[linectr].update(
                    text=[str(round(dataValue, 1)) + ' ' + line['unit']])
                linectr = linectr + 1

            if (line['display'] == 'text'):
                text_values_update.append(
                    pd.to_numeric(data[channel])[len(data[channel]) - 1])
                text_unit_s_update.append(line['unit'])
                text_name_s_update.append(channel)

        for i, text_glyph in enumerate(texts):
            text_glyph.update(text=[
                text_name_s_update[i] + ': ' + str(text_values_update[i]) +
                text_unit_s_update[i]
            ])

        steeringangle = pd.to_numeric(
            data['SteeringAngle'])[len(data['SteeringAngle']) - 1]
        #steering_image.update(angle = steeringangle)
        steering_text.update(
            text=['Steering Angle' + ': ' + str(steeringangle) + 'deg'])

        lat, long = latlong(data)
        gpssource.data.update(x=lat, y=long)
        livesource.data.update(x=lat.iloc[-1:], y=long.iloc[-1:])

    global per_call
    #per_call = doc.add_periodic_callback(update_source, wcuUpdateTimer*1001)
    per_call = doc.add_periodic_callback(callback, wcuUpdateTimer * 1000)

    # Disabled: button to stop the server
    #
    #def exit_callback():
    #    doc.remove_periodic_callback(per_call)
    #    endwculog(wcufilename)
    #
    #button = Button(label="Stop", button_type="success")
    #button.on_click(exit_callback)
    #doc.add_root(button)

    #pre = PreText(text="""Select Witch Channels to Watch""",width=500, height=100)

    global type_graph_option, graph_points_size
    graph_points_size = 2
    type_graph_option = 0

    def addGraph(attrname, old, new):
        '''
        callback function to add graphs figure in the tab area for plotting
        :param attrname:
        :param old: old user selection
        :param new: new user selected channels
        :return: none
        '''

        global old_ch, new_ch
        global type_graph_option, graph_points_size

        # old_ch is only refreshed when the previous selection was accepted
        # (fewer than 5 channels)
        new_ch = new
        if len(old) < 5:
            old_ch = old

        if len(new) < 5:
            uptab = doc.get_model_by_name('graphtab')
            # drop one trailing child per previously selected channel
            for channel in old_ch:
                tb = doc.get_model_by_name('graphtab')
                if channel != '':
                    tb.child.children.remove(
                        tb.child.children[len(tb.child.children) - 1])
            for channel in new_ch:
                plot = figure(plot_height=300,
                              plot_width=1300,
                              title=channel,
                              x_axis_label='s',
                              y_axis_label=channel,
                              toolbar_location="below",
                              tooltips=TOOLTIPS,
                              output_backend=renderer,
                              tools=graphTools,
                              name=channel)

                if type_graph_option == 0:
                    plot.line(x='time', y=channel, color='red', source=source)
                if type_graph_option == 1:
                    plot.circle(x='time',
                                y=channel,
                                color='red',
                                source=source,
                                size=graph_points_size)
                plot.toolbar.logo = None
                uptab.child.children.append(plot)
        else:
            error_2_wcu()

    def radio_group_options(attrname, old, new):
        global type_graph_option
        type_graph_option = new

    OPTIONS_LABEL = ["Line Graph", "Circle Points"]
    radio_group = RadioGroup(labels=OPTIONS_LABEL, active=0)
    radio_group.on_change("active", radio_group_options)

    OPTIONS = cabecalho.split(',')
    multi_select = MultiSelect(value=[''],
                               options=OPTIONS,
                               title='Select Channels',
                               width=300,
                               height=300)
    multi_select.on_change("value", addGraph)

    def update_datapoints(attrname, old, new):
        settings.telemetry_points = new
        if (new > 5000):
            warning_1_wcu()

    datasize_spinner = Spinner(title="Data Points Size",
                               low=1000,
                               high=10000,
                               step=1000,
                               value=1000,
                               width=80)
    datasize_spinner.on_change("value", update_datapoints)

    def update_graph_points_size(attrname, old, new):
        global graph_points_size
        graph_points_size = new

    graph_points_size_spinner = Spinner(title="Circle Size",
                                        low=1,
                                        high=10,
                                        step=1,
                                        value=graph_points_size,
                                        width=80)
    # the circle-size handler belongs to the circle-size spinner, not to
    # datasize_spinner (whose values would be used as glyph sizes otherwise)
    graph_points_size_spinner.on_change("value", update_graph_points_size)

    #make the grid plot of all gauges at the main tab
    # NOTE(review): the layout indexes plot[0]..plot[11], so at least 12
    # gauge channels must have been created above.
    Gauges = gridplot([[plot[0], plot[1], plot[4], plot[3]],
                       [plot[6], plot[2], plot[5], plot[7]],
                       [plot[8], plot[9], plot[10], plot[11]]],
                      toolbar_options={'logo': None})

    #addGraph()
    Graphs = (p)
    layoutGraphs = layout(
        row(Graphs, multi_select,
            column(datasize_spinner, radio_group, graph_points_size_spinner)))
    layoutGauges = layout(row(Gauges, column(track, steering, texplot)))

    Gauges = Panel(child=layoutGauges, title="Gauges", closable=True)
    Graphs = Panel(child=layoutGraphs,
                   title="Graphs",
                   closable=True,
                   name='graphtab')

    def cleanup_session(session_context):
        '''
        This function is called when a session is closed: callback function to end WCU Bokeh server
        :return: none
        '''
        endWCU()

    #activate callback to detect when browser is closed
    doc.on_session_destroyed(cleanup_session)

    tabs = Tabs(tabs=[
        Gauges,
        Graphs,
    ], name='WCU TABS')

    doc.add_root(tabs)
    doc.title = 'WCU SCREEN'
    return doc
Exemplo n.º 8
0
 def create_size_spinner(self) -> Spinner:
     """Build the "Size" spinner.

     The initial value is ``self.text_font_size`` with its last two
     characters removed, converted to ``int``. ``value`` changes are
     forwarded to ``self.handle_size_change``.

     Returns:
         The configured ``Spinner`` widget.
     """
     font_size = int(self.text_font_size[:-2])
     size_input = Spinner(
         title="Size",
         value=font_size,
         low=0,
         step=1,
         width=100,
     )
     size_input.on_change("value", self.handle_size_change)
     return size_input
Exemplo n.º 9
0
 def create_alpha_spinner(self) -> Spinner:
     """Build the "Alpha" spinner.

     The spinner starts at ``self.text_alpha``, is limited to the interval
     0..1, moves in steps of 0.05 and forwards ``value`` changes to
     ``self.handle_alpha_change``.

     Returns:
         The configured ``Spinner`` widget.
     """
     options = dict(
         title="Alpha",
         value=self.text_alpha,
         low=0,
         high=1,
         step=0.05,
         width=100,
     )
     alpha_input = Spinner(**options)
     alpha_input.on_change("value", self.handle_alpha_change)
     return alpha_input
def create():
    doc = curdoc()
    det_data = {}
    cami_meta = {}

    def proposal_textinput_callback(_attr, _old, new):
        nonlocal cami_meta
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith(".hdf"):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list

        cami_meta = {}

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal cami_meta
        with io.StringIO(base64.b64decode(new).decode()) as file:
            cami_meta = pyzebra.parse_h5meta(file)
            file_list = cami_meta["filelist"]
            file_select.options = [(entry, os.path.basename(entry))
                                   for entry in file_list]

    upload_div = Div(text="or upload .cami file:", margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".cami", width=200)
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if main_auto_checkbox.active:
            im_min = np.min(current_image)
            im_max = np.max(current_image)

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

        if "mf" in det_data:
            metadata_table_source.data.update(mf=[det_data["mf"][index]])
        else:
            metadata_table_source.data.update(mf=[None])

        if "temp" in det_data:
            metadata_table_source.data.update(temp=[det_data["temp"][index]])
        else:
            metadata_table_source.data.update(temp=[None])

        gamma, nu = calculate_pol(det_data, index)
        omega = np.ones((IMAGE_H, IMAGE_W)) * det_data["omega"][index]
        image_source.data.update(gamma=[gamma], nu=[nu], omega=[omega])

    def update_overview_plot():
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x],
                                                 dw=[n_x],
                                                 dh=[n_im])
        overview_plot_y_image_source.data.update(image=[overview_y],
                                                 dw=[n_y],
                                                 dh=[n_im])

        if proj_auto_checkbox.active:
            im_min = min(np.min(overview_x), np.min(overview_y))
            im_max = max(np.max(overview_x), np.max(overview_y))

            proj_display_min_spinner.value = im_min
            proj_display_max_spinner.value = im_max

            overview_plot_x_image_glyph.color_mapper.low = im_min
            overview_plot_y_image_glyph.color_mapper.low = im_min
            overview_plot_x_image_glyph.color_mapper.high = im_max
            overview_plot_y_image_glyph.color_mapper.high = im_max

        frame_range.start = 0
        frame_range.end = n_im
        frame_range.reset_start = 0
        frame_range.reset_end = n_im
        frame_range.bounds = (0, n_im)

        scan_motor = det_data["scan_motor"]
        overview_plot_y.axis[1].axis_label = f"Scanning motor, {scan_motor}"

        var = det_data[scan_motor]
        var_start = var[0]
        var_end = var[-1] + (var[-1] - var[0]) / (n_im - 1)

        scanning_motor_range.start = var_start
        scanning_motor_range.end = var_end
        scanning_motor_range.reset_start = var_start
        scanning_motor_range.reset_end = var_end
        # handle both, ascending and descending sequences
        scanning_motor_range.bounds = (min(var_start,
                                           var_end), max(var_start, var_end))

    def file_select_callback(_attr, old, new):
        nonlocal det_data
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            file_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        det_data = pyzebra.read_detector_data(new[0])

        if cami_meta and "crystal" in cami_meta:
            det_data["ub"] = cami_meta["crystal"]["UB"]

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        index_slider.end = det_data["data"].shape[0] - 1

        zebra_mode = det_data["zebra_mode"]
        if zebra_mode == "nb":
            metadata_table_source.data.update(geom=["normal beam"])
        else:  # zebra_mode == "bi"
            metadata_table_source.data.update(geom=["bisecting"])

        update_image(0)
        update_overview_plot()

    file_select = MultiSelect(title="Available .hdf files:",
                              width=210,
                              height=250)
    file_select.on_change("value", file_select_callback)

    def index_callback(_attr, _old, new):
        update_image(new)

    index_slider = Slider(value=0, start=0, end=1, show_value=False, width=400)

    index_spinner = Spinner(title="Image index:", value=0, low=0, width=100)
    index_spinner.on_change("value", index_callback)

    index_slider.js_link("value_throttled", index_spinner, "value")
    index_spinner.js_link("value", index_slider, "value")

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_PLOT_H,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            gamma=[np.zeros((1, 1))],
            nu=[np.zeros((1, 1))],
            omega=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    gamma_glyph = Image(image="gamma",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)
    nu_glyph = Image(image="nu",
                     x="x",
                     y="y",
                     dw="dw",
                     dh="dh",
                     global_alpha=0)
    omega_glyph = Image(image="omega",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)
    plot.add_glyph(image_source, gamma_glyph)
    plot.add_glyph(image_source, nu_glyph)
    plot.add_glyph(image_source, omega_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_PLOT_H,
        plot_width=150,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("gamma", "@gamma"),
        ("nu", "@nu"),
        ("omega", "@omega"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        if new["x"]:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            left = int(np.floor(new["x"][0]))
            right = int(np.ceil(new["x"][0] + new["width"][0]))
            bottom = int(np.floor(new["y"][0]))
            top = int(np.ceil(new["y"][0] + new["height"][0]))
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame ranges
    frame_range = Range1d(0, 1, bounds=(0, 1))
    scanning_motor_range = Range1d(0, 1, bounds=(0, 1))

    det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W))
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_W - 3,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_W],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H))
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_H + 22,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(
        LinearAxis(
            y_range_name="scanning_motor",
            axis_label="Scanning motor",
            major_label_orientation="vertical",
        ),
        place="right",
    )

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_H],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new])
        overview_plot_x_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])
        overview_plot_y_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])

    colormap = Select(title="Colormap:",
                      options=list(cmap_dict.keys()),
                      width=210)
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    STEP = 1

    def main_auto_checkbox_callback(state):
        if state:
            display_min_spinner.disabled = True
            display_max_spinner.disabled = True
        else:
            display_min_spinner.disabled = False
            display_max_spinner.disabled = False

        update_image()

    main_auto_checkbox = CheckboxGroup(labels=["Main Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    main_auto_checkbox.on_click(main_auto_checkbox_callback)

    def display_max_spinner_callback(_attr, _old_value, new_value):
        display_min_spinner.high = new_value - STEP
        image_glyph.color_mapper.high = new_value

    display_max_spinner = Spinner(
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    def display_min_spinner_callback(_attr, _old_value, new_value):
        display_max_spinner.low = new_value + STEP
        image_glyph.color_mapper.low = new_value

    display_min_spinner = Spinner(
        low=0,
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    PROJ_STEP = 0.1

    def proj_auto_checkbox_callback(state):
        if state:
            proj_display_min_spinner.disabled = True
            proj_display_max_spinner.disabled = True
        else:
            proj_display_min_spinner.disabled = False
            proj_display_max_spinner.disabled = False

        update_overview_plot()

    proj_auto_checkbox = CheckboxGroup(labels=["Projections Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    proj_auto_checkbox.on_click(proj_auto_checkbox_callback)

    def proj_display_max_spinner_callback(_attr, _old_value, new_value):
        proj_display_min_spinner.high = new_value - PROJ_STEP
        overview_plot_x_image_glyph.color_mapper.high = new_value
        overview_plot_y_image_glyph.color_mapper.high = new_value

    proj_display_max_spinner = Spinner(
        low=0 + PROJ_STEP,
        value=1,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_max_spinner.on_change("value",
                                       proj_display_max_spinner_callback)

    def proj_display_min_spinner_callback(_attr, _old_value, new_value):
        proj_display_max_spinner.low = new_value + PROJ_STEP
        overview_plot_x_image_glyph.color_mapper.low = new_value
        overview_plot_y_image_glyph.color_mapper.low = new_value

    proj_display_min_spinner = Spinner(
        low=0,
        high=1 - PROJ_STEP,
        value=0,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_min_spinner.on_change("value",
                                       proj_display_min_spinner_callback)

    def hkl_button_callback():
        index = index_spinner.value
        h, k, l = calculate_hkl(det_data, index)
        image_source.data.update(h=[h], k=[k], l=[l])

    hkl_button = Button(label="Calculate hkl (slow)", width=210)
    hkl_button.on_click(hkl_button_callback)

    def events_list_callback(_attr, _old, new):
        doc.events_list_spind.value = new

    events_list = TextAreaInput(rows=7, width=830)
    events_list.on_change("value", events_list_callback)
    doc.events_list_hdf_viewer = events_list

    def add_event_button_callback():
        diff_vec = []
        p0 = [1.0, 0.0, 1.0]
        maxfev = 100000

        wave = det_data["wave"]
        ddist = det_data["ddist"]

        gamma = det_data["gamma"][0]
        omega = det_data["omega"][0]
        nu = det_data["nu"][0]
        chi = det_data["chi"][0]
        phi = det_data["phi"][0]

        scan_motor = det_data["scan_motor"]
        var_angle = det_data[scan_motor]

        x0 = int(np.floor(det_x_range.start))
        xN = int(np.ceil(det_x_range.end))
        y0 = int(np.floor(det_y_range.start))
        yN = int(np.ceil(det_y_range.end))
        fr0 = int(np.floor(frame_range.start))
        frN = int(np.ceil(frame_range.end))
        data_roi = det_data["data"][fr0:frN, y0:yN, x0:xN]

        cnts = np.sum(data_roi, axis=(1, 2))
        coeff, _ = curve_fit(gauss,
                             range(len(cnts)),
                             cnts,
                             p0=p0,
                             maxfev=maxfev)

        m = cnts.mean()
        sd = cnts.std()
        snr_cnts = np.where(sd == 0, 0, m / sd)

        frC = fr0 + coeff[1]
        var_F = var_angle[math.floor(frC)]
        var_C = var_angle[math.ceil(frC)]
        frStep = frC - math.floor(frC)
        var_step = var_C - var_F
        var_p = var_F + var_step * frStep

        if scan_motor == "gamma":
            gamma = var_p
        elif scan_motor == "omega":
            omega = var_p
        elif scan_motor == "nu":
            nu = var_p
        elif scan_motor == "chi":
            chi = var_p
        elif scan_motor == "phi":
            phi = var_p

        intensity = coeff[1] * abs(
            coeff[2] * var_step) * math.sqrt(2) * math.sqrt(np.pi)

        projX = np.sum(data_roi, axis=(0, 1))
        coeff, _ = curve_fit(gauss,
                             range(len(projX)),
                             projX,
                             p0=p0,
                             maxfev=maxfev)
        x_pos = x0 + coeff[1]

        projY = np.sum(data_roi, axis=(0, 2))
        coeff, _ = curve_fit(gauss,
                             range(len(projY)),
                             projY,
                             p0=p0,
                             maxfev=maxfev)
        y_pos = y0 + coeff[1]

        ga, nu = pyzebra.det2pol(ddist, gamma, nu, x_pos, y_pos)
        diff_vector = pyzebra.z1frmd(wave, ga, omega, chi, phi, nu)
        d_spacing = float(pyzebra.dandth(wave, diff_vector)[0])
        diff_vector = diff_vector.flatten() * 1e10
        dv1, dv2, dv3 = diff_vector

        diff_vec.append(diff_vector)

        if events_list.value and not events_list.value.endswith("\n"):
            events_list.value = events_list.value + "\n"

        events_list.value = (
            events_list.value +
            f"{x_pos} {y_pos} {intensity} {snr_cnts} {dv1} {dv2} {dv3} {d_spacing}"
        )

    add_event_button = Button(label="Add spind event")
    add_event_button.on_click(add_event_button_callback)

    metadata_table_source = ColumnDataSource(
        dict(geom=[""], temp=[None], mf=[None]))
    num_formatter = NumberFormatter(format="0.00", nan_format="")
    metadata_table = DataTable(
        source=metadata_table_source,
        columns=[
            TableColumn(field="geom", title="Geometry", width=100),
            TableColumn(field="temp",
                        title="Temperature",
                        formatter=num_formatter,
                        width=100),
            TableColumn(field="mf",
                        title="Magnetic Field",
                        formatter=num_formatter,
                        width=100),
        ],
        width=300,
        height=50,
        autosize_mode="none",
        index_position=None,
    )

    # Final layout
    import_layout = column(proposal_textinput, upload_div, upload_button,
                           file_select)
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False))
    colormap_layout = column(
        colormap,
        main_auto_checkbox,
        row(display_min_spinner, display_max_spinner),
        proj_auto_checkbox,
        row(proj_display_min_spinner, proj_display_max_spinner),
    )

    layout_controls = column(
        row(metadata_table, index_spinner,
            column(Spacer(height=25), index_slider)),
        row(add_event_button, hkl_button),
        row(events_list),
    )

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
            toolbar_location="left",
        ), )

    tab_layout = row(
        column(import_layout, colormap_layout),
        column(layout_overview, layout_controls),
        column(roi_avg_plot, layout_image),
    )

    return Panel(child=tab_layout, title="hdf viewer")
Exemplo n.º 11
0
    def __init__(
        self,
        nplots,
        plot_height=350,
        plot_width=700,
        lower=0,
        upper=1000,
        nbins=100,
    ):
        """Create the histogram figures together with their shared control widgets.

        Args:
            nplots (int): How many histogram plots to build; all of them are driven by the
                same set of controls.
            plot_height (int, optional): Plot area height in screen pixels. Defaults to 350.
            plot_width (int, optional): Plot area width in screen pixels. Defaults to 700.
            lower (int, optional): Starting value of the lower bin range. Defaults to 0.
            upper (int, optional): Starting value of the upper bin range. Defaults to 1000.
            nbins (int, optional): Starting number of bins. Defaults to 100.
        """
        # Histogram plots
        self.plots = []
        self._plot_sources = []
        for plot_index in range(nplots):
            figure = Plot(
                x_range=DataRange1d(),
                y_range=DataRange1d(),
                plot_height=plot_height,
                plot_width=plot_width,
                toolbar_location="left",
            )

            # ---- tools
            figure.toolbar.logo = None
            # The pan, box-zoom and wheel-zoom tool objects are created for the first plot
            # only and then attached to every plot, so all plots share the same instances.
            if plot_index == 0:
                shared_pan = PanTool()
                shared_box_zoom = BoxZoomTool()
                shared_wheel_zoom = WheelZoomTool()
            figure.add_tools(
                shared_pan,
                shared_box_zoom,
                shared_wheel_zoom,
                SaveTool(),
                ResetTool(),
            )

            # ---- axes
            figure.add_layout(LinearAxis(), place="below")
            figure.add_layout(
                LinearAxis(major_label_orientation="vertical"),
                place="left",
            )

            # ---- grid lines, one per dimension
            for dimension in (0, 1):
                figure.add_layout(
                    Grid(dimension=dimension, ticker=BasicTicker()))

            # ---- quad glyph drawing the bins, fed from an initially empty source
            bins_source = ColumnDataSource(dict(left=[], right=[], top=[]))
            bins_glyph = Quad(
                left="left",
                right="right",
                top="top",
                bottom=0,
                fill_color="steelblue",
            )
            figure.add_glyph(bins_source, bins_glyph)

            self.plots.append(figure)
            self._plot_sources.append(bins_source)

        self._counts = []
        self._empty_counts()

        # Histogram controls
        # ---- automatic / manual range checkbox
        def auto_toggle_callback(state):
            # A truthy state means automatic range: the manual spinners are locked.
            # A falsy state means manual range: the spinners are editable.
            locked = bool(state)
            range_low_spinner.disabled = locked
            range_high_spinner.disabled = locked

        range_mode_checkbox = CheckboxGroup(
            labels=["Auto Hist Range"],
            active=[0],
            default_size=145,
        )
        range_mode_checkbox.on_click(auto_toggle_callback)
        self.auto_toggle = range_mode_checkbox

        # Both range spinners start disabled whenever the checkbox has an active entry.
        # ---- lower edge of the histogram range
        def lower_spinner_callback(_attr, _old_value, new_value):
            # Keep the upper spinner at least one STEP above the new lower value.
            self.upper_spinner.low = new_value + STEP
            self._empty_counts()

        range_low_spinner = Spinner(
            title="Lower Range:",
            high=upper - STEP,
            value=lower,
            step=STEP,
            disabled=bool(range_mode_checkbox.active),
            default_size=145,
        )
        range_low_spinner.on_change("value", lower_spinner_callback)
        self.lower_spinner = range_low_spinner

        # ---- upper edge of the histogram range
        def upper_spinner_callback(_attr, _old_value, new_value):
            # Keep the lower spinner at least one STEP below the new upper value.
            self.lower_spinner.high = new_value - STEP
            self._empty_counts()

        range_high_spinner = Spinner(
            title="Upper Range:",
            low=lower + STEP,
            value=upper,
            step=STEP,
            disabled=bool(range_mode_checkbox.active),
            default_size=145,
        )
        range_high_spinner.on_change("value", upper_spinner_callback)
        self.upper_spinner = range_high_spinner

        # ---- number of bins
        def nbins_spinner_callback(_attr, _old_value, _new_value):
            self._empty_counts()

        bin_count_spinner = Spinner(
            title="Number of Bins:",
            low=1,
            value=nbins,
            default_size=145,
        )
        bin_count_spinner.on_change("value", nbins_spinner_callback)
        self.nbins_spinner = bin_count_spinner

        # ---- log10 of counts checkbox
        def log10counts_toggle_callback(state):
            self._empty_counts()
            # Relabel the first y-axis of every plot to match the selected scale.
            axis_label = "log⏨(Counts)" if state else "Counts"
            for histogram in self.plots:
                histogram.yaxis[0].axis_label = axis_label

        log_scale_checkbox = CheckboxGroup(
            labels=["log⏨(Counts)"],
            default_size=145,
        )
        log_scale_checkbox.on_click(log10counts_toggle_callback)
        self.log10counts_toggle = log_scale_checkbox
Exemplo n.º 12
0
 def create_y_control(self) -> "Spinner":
     """Create the spinner widget that edits the glyph's y coordinate.

     The spinner is initialised from ``self.plot.glyph.y`` and every change of
     its value is forwarded to ``self.on_coordinate_change`` with ``"y"`` bound
     as the leading argument.

     Returns:
         The configured ``Spinner`` (the previous ``tuple`` annotation did not
         match what the method returns).
     """
     y_value = self.plot.glyph.y
     y_control = Spinner(value=y_value, title="y", step=0.05, width=100)
     y_control.on_change("value", partial(self.on_coordinate_change, "y"))
     return y_control
Exemplo n.º 13
0
 def create_x_control(self) -> "Spinner":
     """Create the spinner widget that edits the glyph's x coordinate.

     The spinner is initialised from ``self.plot.glyph.x`` and every change of
     its value is forwarded to ``self.on_coordinate_change`` with ``"x"`` bound
     as the leading argument.

     Returns:
         The configured ``Spinner`` (the previous ``TextInput`` annotation did
         not match what the method returns).
     """
     x_value = self.plot.glyph.x
     x_control = Spinner(value=x_value, title="x", step=0.05, width=100)
     x_control.on_change("value", partial(self.on_coordinate_change, "x"))
     return x_control
Exemplo n.º 14
0
def create():
    """Build the "Data Viewer" tab.

    The tab combines a file selector fed by an uploaded .cami file, a
    single-frame image with its two mean projections, two overview plots of
    the whole stack, a per-frame plot of the box-selected region, colormap and
    display-range controls, an hkl calculation button and a list of recorded
    range selections.

    Returns:
        A ``Panel`` titled "Data Viewer" wrapping the assembled layout.
    """
    # State shared by the nested callbacks below.
    det_data = {}
    roi_selection = {}

    upload_div = Div(text="Open .cami file:")

    def upload_button_callback(_attr, _old, new):
        """Parse the uploaded file and put its file list into the selector."""
        # The widget value is base64-decoded before being handed to the parser.
        with io.StringIO(base64.b64decode(new).decode()) as file:
            h5meta_list = pyzebra.parse_h5meta(file)
            file_list = h5meta_list["filelist"]
            filelist.options = file_list
            filelist.value = file_list[0]

    upload_button = FileInput(accept=".cami")
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        """Show frame ``index`` (default: the index spinner's value)."""
        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        # Mean projections; the +0.5 offset places each point mid-pixel.
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        # Reset the hkl layers to placeholders: they belong to the previous
        # frame until "Calculate hkl" is pressed again.
        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if auto_toggle.active:
            im_max = int(np.max(current_image))
            im_min = int(np.min(current_image))

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

    def update_overview_plot():
        """Refresh both overview images and their vertical axis (frame or omega)."""
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x], dw=[n_x])
        overview_plot_y_image_source.data.update(image=[overview_y], dw=[n_y])

        if frame_button_group.active == 0:  # Frame
            overview_plot_x.axis[1].axis_label = "Frame"
            overview_plot_y.axis[1].axis_label = "Frame"

            overview_plot_x_image_source.data.update(y=[0], dh=[n_im])
            overview_plot_y_image_source.data.update(y=[0], dh=[n_im])

        elif frame_button_group.active == 1:  # Omega
            overview_plot_x.axis[1].axis_label = "Omega"
            overview_plot_y.axis[1].axis_label = "Omega"

            om = det_data["rot_angle"]
            om_start = om[0]
            if n_im > 1:
                # (om[-1] - om[0]) covers n_im - 1 steps; scale it to n_im
                # steps so that every frame gets a row of equal height.
                om_span = (om[-1] - om[0]) * n_im / (n_im - 1)
            else:
                # A single frame provides no angular step (the expression
                # above would divide by zero); fall back to a unit height.
                om_span = 1
            overview_plot_x_image_source.data.update(y=[om_start],
                                                     dh=[om_span])
            overview_plot_y_image_source.data.update(y=[om_start],
                                                     dh=[om_span])

    def filelist_callback(_attr, _old, new):
        """Load the newly selected file and reset the view to its first frame."""
        nonlocal det_data
        det_data = pyzebra.read_detector_data(new)

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        update_image(0)
        update_overview_plot()

    filelist = Select()
    filelist.on_change("value", filelist_callback)

    def index_spinner_callback(_attr, _old, new):
        """Display the frame chosen in the index spinner."""
        update_image(new)

    index_spinner = Spinner(title="Image index:", value=0, low=0)
    index_spinner.on_change("value", index_spinner_callback)

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- image glyphs (colour-mapped scalar images)
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    # The h/k/l layers are fully transparent; they are referenced by the hover
    # tool's "@h", "@k" and "@l" tooltip fields.
    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=200,
        plot_width=IMAGE_W * 3,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_H * 3,
        plot_width=200,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        """Plot the per-frame sum of the counts inside the drawn box."""
        if new["x"]:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            # Rect glyph coordinates are the rectangle's centre, so the edges
            # lie half a width / height away on either side.
            half_width = new["width"][0] / 2
            half_height = new["height"][0] / 2
            # Clamp at 0: a negative slice bound would wrap around to the
            # opposite side of the detector instead of stopping at the edge.
            left = max(int(np.floor(new["x"][0] - half_width)), 0)
            right = max(int(np.ceil(new["x"][0] + half_width)), 0)
            bottom = max(int(np.floor(new["y"][0] - half_height)), 0)
            top = max(int(np.ceil(new["y"][0] + half_height)), 0)
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            # No box left: clear the curve.
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame range
    frame_range = DataRange1d()
    det_x_range = DataRange1d()
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- image glyph (colour-mapped scalar image)
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = DataRange1d()
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        plot_height=400,
        plot_width=400,
        toolbar_location="left",
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- image glyph (colour-mapped scalar image)
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[1],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    def frame_button_group_callback(_active):
        """Redraw the overview plots for the newly chosen vertical axis."""
        update_overview_plot()

    frame_button_group = RadioButtonGroup(labels=["Frames", "Omega"], active=0)
    frame_button_group.on_click(frame_button_group_callback)

    # Per-frame curve of the box-selected region (filled by box_edit_callback).
    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=IMAGE_H * 3,
        plot_width=IMAGE_W * 3,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        """Give every image glyph a fresh linear mapper with the chosen palette."""
        image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new])
        overview_plot_x_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])
        overview_plot_y_image_glyph.color_mapper = LinearColorMapper(
            palette=cmap_dict[new])

    colormap = Select(title="Colormap:", options=list(cmap_dict.keys()))
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    radio_button_group = RadioButtonGroup(labels=["nb", "nb_bi"], active=0)

    STEP = 1

    # ---- colormap auto toggle button
    def auto_toggle_callback(state):
        """Lock the display-range spinners in auto mode, then redraw the image."""
        if state:
            display_min_spinner.disabled = True
            display_max_spinner.disabled = True
        else:
            display_min_spinner.disabled = False
            display_max_spinner.disabled = False

        update_image()

    auto_toggle = Toggle(label="Auto Range",
                         active=True,
                         button_type="default")
    auto_toggle.on_click(auto_toggle_callback)

    # ---- colormap display max value
    def display_max_spinner_callback(_attr, _old_value, new_value):
        """Apply the new maximum and keep the minimum spinner below it."""
        display_min_spinner.high = new_value - STEP
        image_glyph.color_mapper.high = new_value

    display_max_spinner = Spinner(
        title="Maximal Display Value:",
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    # ---- colormap display min value
    def display_min_spinner_callback(_attr, _old_value, new_value):
        """Apply the new minimum and keep the maximum spinner above it."""
        display_max_spinner.low = new_value + STEP
        image_glyph.color_mapper.low = new_value

    display_min_spinner = Spinner(
        title="Minimal Display Value:",
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=auto_toggle.active,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    def hkl_button_callback():
        """Compute the h, k, l layers for the current frame and store them."""
        index = index_spinner.value
        setup_type = "nb_bi" if radio_button_group.active else "nb"
        h, k, l = calculate_hkl(det_data, index, setup_type)
        image_source.data.update(h=[h], k=[k], l=[l])

    hkl_button = Button(label="Calculate hkl (slow)")
    hkl_button.on_click(hkl_button_callback)

    selection_list = TextAreaInput(rows=7)

    def selection_button_callback():
        """Record the current overview-plot ranges under the selected file's id."""
        selection = [
            int(np.floor(det_x_range.start)),
            int(np.ceil(det_x_range.end)),
            int(np.floor(det_y_range.start)),
            int(np.ceil(det_y_range.end)),
            int(np.floor(frame_range.start)),
            int(np.ceil(frame_range.end)),
        ]

        # NOTE(review): takes the 4 characters before the last 4 of the file
        # name as its id - presumably "NNNN.ext"; confirm the naming scheme.
        filename_id = filelist.value[-8:-4]
        roi_selection.setdefault(filename_id, []).append(selection)

        selection_list.value = str(roi_selection)

    selection_button = Button(label="Add selection")
    selection_button.on_click(selection_button_callback)

    # Final layout
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False),
        row(index_spinner))
    colormap_layout = column(colormap, auto_toggle, display_max_spinner,
                             display_min_spinner)
    hkl_layout = column(radio_button_group, hkl_button)

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
        ),
        frame_button_group,
    )

    tab_layout = row(
        column(
            upload_div,
            upload_button,
            filelist,
            layout_image,
            row(colormap_layout, hkl_layout),
        ),
        column(
            roi_avg_plot,
            layout_overview,
            row(selection_button, selection_list),
        ),
    )

    return Panel(child=tab_layout, title="Data Viewer")
Exemplo n.º 15
0
class EeghdfBrowser:
    """
    take an hdfeeg file and allow for browsing of the EEG signal

    just use the raw hdf file and conventions for now

    """
    def __init__(
        self,
        eeghdf_file,
        page_width_seconds=10.0,
        start_seconds=-1,
        montage="trace",
        montage_options=None,
        yscale=1.0,
        plot_width=950,
        plot_height=600,
    ):
        """
        set up the browsing state and the filter caches for one eeghdf file

        @eeghdf_file - file object handed to update_eeghdf_file()
        @page_width_seconds - how big to make the view in seconds
        @start_seconds - center view on this point in time; a negative value
            centers the view on page_width_seconds / 2
        @montage - montageview (class factory) OR a string that identifies a
            montage in the montage options
        @montage_options - optional mapping of montage name -> factory; None or
            an empty mapping is passed through and update_eeghdf_file() then
            falls back to the builtin montages
        @yscale - multiplier applied to each trace when plotting
        @plot_width, @plot_height - default size used when the figure is created

        BTW 'trace' is what NK calls its 'as recorded' montage - might be better
        to call 'raw', 'default' or 'as recorded'
        """
        # montage_options defaults to None instead of {} so that no dict object
        # is shared between calls; update_eeghdf_file() only tests truthiness.
        self.eeghdf_file = eeghdf_file
        self.update_eeghdf_file(eeghdf_file, montage, montage_options)

        # display related; both attribute spellings are kept because methods in
        # this class read page_width_secs while page_width_seconds was also set
        self.page_width_seconds = page_width_seconds
        self.page_width_secs = page_width_seconds

        # a negative start means "not specified": show the first page
        if start_seconds < 0:
            self.loc_sec = page_width_seconds / 2.0
        else:
            self.loc_sec = start_seconds

        self.yscale = yscale
        self.ui_plot_width = plot_width
        self.ui_plot_height = plot_height

        ## bokeh related
        self.bk_handle = None
        self.fig = None

        self.update_title()

        self.num_rows, self.num_samples = self.signals.shape
        self.line_glyphs = []  # not used?
        self.multi_line_glyph = None

        self.ch_start = 0
        self.ch_stop = self.current_montage_instance.shape[0]

        ####### set up filter caches (label -> filter, insertion ordered)
        fs = self.fs  # set by update_eeghdf_file() above
        self.current_lp_filter = None

        highpass = OrderedDict()
        highpass["None"] = None
        highpass["0.1 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs, cutoff_freq=0.1, numtaps=int(fs))
        highpass["0.3 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs, cutoff_freq=0.3, numtaps=int(fs))
        highpass["1 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs=fs, cutoff_freq=1.0, numtaps=int(2 * fs))
        highpass["5 Hz"] = esfilters.fir_highpass_firwin_ff(
            fs=fs, cutoff_freq=5.0, numtaps=int(0.2 * fs))
        self._highpass_cache = highpass
        # the high-pass filter that is active on start-up
        self.current_hp_filter = highpass["0.3 Hz"]

        lowpass = OrderedDict()
        lowpass["None"] = None
        # (label, cutoff in Hz, divisor applied to fs to get numtaps)
        lowpass_specs = (
            ("15 Hz", 15.0, 2.0),
            ("30 Hz", 30.0, 4.0),
            ("50 Hz", 50.0, 4.0),
            ("70 Hz", 70.0, 4.0),
        )
        for label, cutoff, divisor in lowpass_specs:
            lowpass[label] = esfilters.fir_lowpass_firwin_ff(
                fs=fs, cutoff_freq=cutoff, numtaps=int(fs / divisor))
        self._lowpass_cache = lowpass

        # the notch filter is built up front but stays inactive until selected
        self._notch_filter = esfilters.notch_filter_iir_ff(notch_freq=60.0,
                                                           fs=fs,
                                                           Q=60)
        self.current_notch_filter = None

    @property
    def signals(self):
        return self.eeghdf_file.phys_signals

    def update_eeghdf_file(self,
                           eeghdf_file,
                           montage="trace",
                           montage_options={}):
        self.eeghdf_file = eeghdf_file
        hdf = eeghdf_file.hdf
        rec = hdf["record-0"]
        self.fs = rec.attrs["sample_frequency"]
        # self.signals = rec['signals']
        blabels = rec["signal_labels"]  # byte labels
        # self.electrode_labels = [str(ss,'ascii') for ss in blabels]
        self.electrode_labels = eeghdf_file.electrode_labels
        # fill in any missing ones
        if len(self.electrode_labels) < eeghdf_file.phys_signals.shape[0]:
            d = eeghdf_file.phys_signals.shape[0] - len(self.electrode_labels)
            ll = len(self.electrode_labels)
            suppl = [str(ii) for ii in range(ll, ll + d)]
            self.electrode_labels += suppl
            print("extending electrode lables:", suppl)

        # reference labels are used for montages, since this is an eeghdf file, it can provide these

        self.ref_labels = eeghdf_file.shortcut_elabels

        if not montage_options:
            # then use builtins and/or ones in the file
            montage_options = montageview.MONTAGE_BUILTINS.copy()
            # print('starting build of montage options', montage_options)

            # montage_options = eeghdf_file.get_montages()

        # defines self.current_montage_instance
        self.current_montage_instance = None
        if type(montage) == str:  # then we have some work to do
            if montage in montage_options:

                self.current_montage_instance = montage_options[montage](
                    self.ref_labels)
            else:
                raise Exception("unrecognized montage: %s" % montage)
        else:
            if montage:  # is a class
                self.current_montage_instance = montage(self.ref_labels)
                montage_options[self.current_montage_instance.name] = montage
            else:  # use default

                self.current_montage_instance = montage_options[0](
                    self.ref_labels)

        assert self.current_montage_instance
        try:  # to update ui display
            self.ui_montage_dropdown.value = self.current_montage_instance.name
        except AttributeError:
            # guess is not yet instantiated
            pass

        self.montage_options = montage_options  # save the montage_options for later
        self.update_title()
        # note this does not do any plotting or update the plot

    def update_title(self):
        self.title = "hdf %s - montage: %s" % (
            self.eeghdf_file.hdf.filename,
            self.current_montage_instance.full_name
            if self.current_montage_instance else "",
        )

    #         if showchannels=='all':
    #             self.ch_start = 0  # change this to a list of channels for fancy slicing
    #             if montage:
    #                 self.ch_stop = montage.shape[0] # all the channels in the montage
    #             self.ch_stop = signals.shape[0] # all the channels in the original signal
    #         else:
    #             self.ch_start, self.ch_stop = showchannels
    #         self.num_rows, self.num_samples = signals.shape
    #         self.line_glyphs = []
    #         self.multi_line_glyph = None

    def plot(self):
        """
        draw the page centered on self.loc_sec with the current montage

        all montage channels are shown; the x axis is labelled in seconds and
        gets one vertical grid line per second. returns the figure.
        """
        montage = self.current_montage_instance
        fig = self.show_montage_centered(
            self.signals,
            self.loc_sec,
            page_width_sec=self.page_width_secs,
            chstart=0,
            chstop=montage.shape[0],
            fs=self.fs,
            ylabels=montage.montage_labels,
            yscale=self.yscale,
            montage=montage,
        )
        self.fig = fig
        fig.xaxis.axis_label = "seconds"
        # make the xgrid mark every second
        one_second_ticks = SingleIntervalTicker(interval=1.0)
        fig.xgrid.ticker = one_second_ticks
        return fig

    def show_for_bokeh_app(self):
        """try running intside a bokeh app, so don't need notebook stuff"""
        self.plot()

    def bokeh_show(self):
        """
        notebook entry point: plot, build the top bar controls, show the
        figure with a notebook handle (kept in self.bk_handle), then build
        the bottom bar controls
        """
        self.plot()
        # the controls above the plot are created before the figure is shown
        self.register_top_bar_ui()
        handle = bokeh.plotting.show(self.fig, notebook_handle=True)
        self.bk_handle = handle
        self.register_bottom_bar_ui()

    def update(self):
        """
        updates the data in the plot
        so that it will show up
        can either use bokeh.io.push_notebook()
        or panel.pane.Bokeh(..)
        to handle event loop

        takes the page of samples centered on self.loc_sec, applies the
        current montage, keeps rows self.ch_start:self.ch_stop, runs the
        active notch / high-pass / low-pass filters (in that order) and
        writes the resulting xs / ys into self.data_source

        requires self.ticklocs and self.data_source, which are set when the
        figure is first drawn
        """
        fs = self.fs
        goto_sample = int(fs * self.loc_sec)
        page_width_samples = int(self.page_width_secs * fs)
        half_width = int(page_width_samples / 2)
        # the window is clipped to the recording, so it can be shorter than a page
        s0 = limit_sample_check(goto_sample - half_width, self.signals)
        s1 = limit_sample_check(goto_sample + half_width, self.signals)

        signal_view = self.signals[:, s0:s1]
        inmontage_view = np.dot(self.current_montage_instance.V.data,
                                signal_view)

        data = inmontage_view[self.ch_start:self.ch_stop, :]
        # count the rows that are actually displayed; using the row count of
        # the unsliced montage view indexes past the end of `data` whenever
        # ch_start/ch_stop select a subset of channels
        num_rows, num_samples = data.shape

        ########## do filtering here ############
        # each active filter is applied to every row before the next filter runs
        for row_filter in (
                self.current_notch_filter,
                self.current_hp_filter,
                self.current_lp_filter,
        ):
            if row_filter:
                for ii in range(num_rows):
                    data[ii, :] = row_filter(data[ii, :])
        ## end filtering

        # time of each sample in seconds from the start of the recording; this
        # matches show_montage_centered() and stays correct when the window
        # was clipped and therefore spans less than page_width_secs
        t = (s0 + np.arange(num_samples, dtype=float)) / fs

        xs = [t for ii in range(num_rows)]
        ys = [
            self.yscale * data[ii, :] + self.ticklocs[ii]
            for ii in range(num_rows)
        ]

        # update both columns in a single call
        self.data_source.data.update(dict(xs=xs, ys=ys))
        # old way
        # self.data_source.data['xs'] = xs
        # self.data_source.data['ys'] = ys

        #self.push_notebook()
        # do pane.Bokeh::param.trigger('object') on pane holding EEG waveform plot
        # in notebook updates without a trigger

    def stackplot_t(
        self,
        tarray,
        seconds=None,
        start_time=None,
        ylabels=None,
        yscale=1.0,
        topdown=True,
        **kwargs,
    ):
        """
        will plot a stack of traces one above the other assuming
        @tarray is an nd-array like object with format
        tarray.shape =  numSamples, numRows

        @seconds = width of plot in seconds for labeling purposes (optional)
        @start_time is start time in seconds for the plot (optional)

        @ylabels a list of labels for each row ("channel") in tarray
        @yscale will increase (multiply) the signals in each row by this amount
        @topdown if true the first row is drawn at the top
        @kwargs are passed to bokeh.plotting.figure when the figure has to be
        created; plot_width / plot_height default to the ui_plot_* attributes

        the figure is created on first use and stored in self.fig; also sets
        self.y_range, self.ticklocs, self.time, self.multi_line_glyph and
        self.data_source. returns the figure

        the input is copied to a float array first, so the filters below do
        not modify the caller's data
        """
        data = np.array(tarray, dtype=float)
        numSamples, numRows = data.shape

        if seconds:
            t = seconds * np.arange(numSamples, dtype=float) / numSamples
            if start_time:
                t = t + start_time
        else:
            # no duration given: use the sample index as the x coordinate
            t = np.arange(numSamples, dtype=float)

        kwargs.setdefault("plot_width", self.ui_plot_width)
        kwargs.setdefault("plot_height", self.ui_plot_height)

        if not self.fig:
            # bokeh.plotting.figure creates a subclass of plot
            self.fig = bokeh.plotting.figure(
                title=self.title,
                tools="pan,box_zoom,reset,lasso_select,ywheel_zoom",
                **kwargs,
            )

        # vertical spacing between traces, taken from the unfiltered data
        dmin = data.min()
        dmax = data.max()
        dr = (dmax - dmin) * 0.7  # Crowd them a bit.

        ticklocs = [ii * dr for ii in range(numRows)]
        bottom = -dr / 0.7
        top = (numRows - 1) * dr + dr / 0.7
        self.y_range = Range1d(bottom, top)
        self.fig.y_range = self.y_range

        if topdown:
            ticklocs.reverse()  # inplace

        self.ticklocs = ticklocs
        self.time = t

        ########## do filtering here ############
        # data is (numSamples, numRows) here, so one channel is a column;
        # every filter has to index data[:, ii], not data[ii, :]
        for channel_filter in (
                self.current_notch_filter,
                self.current_hp_filter,
                self.current_lp_filter,
        ):
            if channel_filter:
                for ii in range(numRows):
                    data[:, ii] = channel_filter(data[:, ii])
        ## end filtering

        ## build the columns for a single multi_line glyph
        xs = [t for ii in range(numRows)]
        ys = [yscale * data[:, ii] + ticklocs[ii] for ii in range(numRows)]

        self.multi_line_glyph = self.fig.multi_line(xs=xs, ys=ys)
        self.data_source = self.multi_line_glyph.data_source

        # set the yticks to use axes coords on the y axis
        if not ylabels:
            ylabels = ["%d" % ii for ii in range(numRows)]
        ylabel_dict = dict(zip(ticklocs, ylabels))
        self.fig.yaxis.ticker = FixedTicker(
            ticks=ticklocs)  # can also short cut to give list directly
        # the tick position -> label mapping is embedded in the JS formatter
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % ylabel_dict)

        return self.fig

    def update_plot_after_montage_change(self):
        """
        refresh title, y range, tick positions and tick labels for the
        current montage

        the trace spacing is derived from the page of samples around
        self.loc_sec; ch_start / ch_stop are reset to cover every montage row.
        the plotted line data is not changed here
        """
        self.fig.title.text = self.title

        center_sample = int(self.fs * self.loc_sec)
        half_page = int(int(self.page_width_secs * self.fs) / 2)
        first = limit_sample_check(center_sample - half_page, self.signals)
        last = limit_sample_check(center_sample + half_page, self.signals)

        montage = self.current_montage_instance
        page = np.dot(montage.V.data, self.signals[:, first:last])

        row_count = page.shape[0]
        self.ch_start = 0
        self.ch_stop = row_count

        # spacing between neighbouring traces
        spacing = (page.max() - page.min()) * 0.7  # Crowd them a bit.

        # first row gets the highest position
        tick_positions = [idx * spacing for idx in reversed(range(row_count))]

        # the existing Range1d is updated in place rather than replaced
        self.y_range.start = -spacing / 0.7
        self.y_range.end = (row_count - 1) * spacing + spacing / 0.7

        self.ticklocs = tick_positions

        labels_by_tick = dict(zip(tick_positions, montage.montage_labels))
        self.fig.yaxis.ticker = FixedTicker(ticks=tick_positions)
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % labels_by_tick)

        ## experiment with clearing the data source
        # self.data_source.data.clear() # vs .update() ???

    def stackplot(
        self,
        marray,
        seconds=None,
        start_time=None,
        ylabels=None,
        yscale=1.0,
        topdown=True,
        **kwargs,
    ):
        """
        will plot a stack of traces one above the other assuming
        @marray contains the data you want to plot
        marray.shape = numRows, numSamples

        @seconds = with of plot in seconds for labeling purposes (optional)
        @start_time is start time in seconds for the plot (optional)

        @ylabels a list of labels for each row ("channel") in marray
        @yscale with increase (mutiply) the signals in each row by this amount
        """
        tarray = np.transpose(marray)
        return self.stackplot_t(
            tarray,
            seconds=seconds,
            start_time=start_time,
            ylabels=ylabels,
            yscale=yscale,
            topdown=True,
            **kwargs,
        )

    def show_epoch_centered(
        self,
        signals,
        goto_sec,
        page_width_sec,
        chstart,
        chstop,
        fs,
        ylabels=None,
        yscale=1.0,
    ):
        """
        plot channels chstart:chstop of @signals for a window of
        @page_width_sec centered on @goto_sec, without applying a montage

        @signals array-like object with signals[ch_num, sample_num]
        @goto_sec where to go in the signal to show the feature
        @page_width_sec length of the window to show in secs
        @chstart   which channel to start
        @chstop    which channel to end
        @fs sample frequency (num samples per second)
        @ylabels optional list with one label per channel of @signals; the
                 entries chstart:chstop are used. When None, stackplot falls
                 back to its default labels
        @yscale multiplier applied to each trace

        returns the figure produced by stackplot
        """
        goto_sample = int(fs * goto_sec)
        half_width = int(page_width_sec * fs / 2)

        # the window is clipped to the available samples
        s0 = limit_sample_check(goto_sample - half_width, signals)
        s1 = limit_sample_check(goto_sample + half_width, signals)
        duration = (s1 - s0) / fs
        start_time_sec = s0 / fs

        # ylabels is optional; slicing None would raise TypeError
        shown_labels = None if ylabels is None else ylabels[chstart:chstop]

        return self.stackplot(
            signals[chstart:chstop, s0:s1],
            start_time=start_time_sec,
            seconds=duration,
            ylabels=shown_labels,
            yscale=yscale,
        )

    def show_montage_centered(
        self,
        signals,
        goto_sec,
        page_width_sec,
        chstart,
        chstop,
        fs,
        ylabels=None,
        yscale=1.0,
        montage=None,
        topdown=True,
        **kwargs,
    ):
        """
        plot an eeg segment using a montage, center the plot at @goto_sec
        with @page_width_sec shown

        @signals array-like object with signals[ch_num, sample_num]

        @goto_sec where to go in the signal to show the feature
        @page_width_sec length of the window to show in secs
        @chstart   which montage row to start
        @chstop    which montage row to end

        @fs sample frequency (num samples per second)

        @ylabels a list of labels for the displayed rows; when empty or None
                 the row index is used
        @yscale will increase (multiply) the signals in each row by this amount
        @montage instance; when None self.current_montage_instance is used
        @topdown if true the first row is drawn at the top
        @kwargs are passed to bokeh.plotting.figure when the figure has to be
        created; plot_width / plot_height default to the ui_plot_* attributes

        the figure is created on first use and stored in self.fig; also sets
        self.y_range, self.ticklocs, self.time, self.multi_line_glyph and
        self.data_source. returns the figure
        """
        if montage is None:
            # the default would otherwise fail on montage.V below
            montage = self.current_montage_instance

        goto_sample = int(fs * goto_sec)
        half_width = int(page_width_sec * fs / 2)

        # the window is clipped to the available samples
        s0 = limit_sample_check(goto_sample - half_width, signals)
        s1 = limit_sample_check(goto_sample + half_width, signals)
        duration_sec = (s1 - s0) / fs
        start_time_sec = s0 / fs

        signal_view = signals[:, s0:s1]
        inmontage_view = np.dot(montage.V.data, signal_view)

        data = inmontage_view[chstart:chstop, :]
        numRows, numSamples = data.shape

        t = duration_sec * np.arange(numSamples, dtype=float) / numSamples
        t = t + start_time_sec  # shift over

        kwargs.setdefault("plot_width", self.ui_plot_width)
        kwargs.setdefault("plot_height", self.ui_plot_height)

        if not self.fig:
            self.fig = bokeh.plotting.figure(
                title=self.title,
                tools="crosshair",
                **kwargs,
            )  # subclass of Plot that simplifies plot creation

        # vertical spacing between traces, taken from the unfiltered data
        dmin = data.min()
        dmax = data.max()
        dr = (dmax - dmin) * 0.7  # Crowd them a bit.

        ticklocs = [ii * dr for ii in range(numRows)]
        bottom = -dr / 0.7
        top = (numRows - 1) * dr + dr / 0.7
        self.y_range = Range1d(bottom, top)
        self.fig.y_range = self.y_range

        if topdown:
            ticklocs.reverse()  # inplace

        self.ticklocs = ticklocs
        self.time = t

        ########## do filtering here ############
        # data is (numRows, numSamples) here, so one channel is a row; each
        # active filter is applied to every row before the next filter runs
        for row_filter in (
                self.current_notch_filter,
                self.current_hp_filter,
                self.current_lp_filter,
        ):
            if row_filter:
                for ii in range(numRows):
                    data[ii, :] = row_filter(data[ii, :])
        ## end filtering

        ## build the columns for a single multi_line glyph
        xs = [t for ii in range(numRows)]
        ys = [yscale * data[ii, :] + ticklocs[ii] for ii in range(numRows)]

        self.multi_line_glyph = self.fig.multi_line(xs=xs, ys=ys)
        self.data_source = self.multi_line_glyph.data_source

        # set the yticks to use axes coords on the y axis
        if not ylabels:
            ylabels = ["%d" % ii for ii in range(numRows)]
        ylabel_dict = dict(zip(ticklocs, ylabels))
        self.fig.yaxis.ticker = FixedTicker(
            ticks=ticklocs)  # can also short cut to give list directly
        # the tick position -> label mapping is embedded in the JS formatter
        self.fig.yaxis.formatter = FuncTickFormatter(code="""
            var labels = %s;
            return labels[tick];
        """ % ylabel_dict)
        return self.fig

    def register_top_bar_ui(self):
        """
        create the controls shown above the plot and return them in a row:
        montage selector, low/high frequency filter selectors, 60Hz notch
        checkbox and gain field

        NOTE(review): this class defines register_top_bar_ui a second time
        further down; the later definition replaces this one, so this version
        is not the one that runs. It still mixes a bokeh Select with
        ipywidgets controls inside a bokeh row - confirm whether it should be
        removed.
        """
        self.ui_montage_dropdown = Select(
            options=list(self.montage_options.keys()),
            value=self.current_montage_instance.name,
            title="Montage:",
        )

        def on_dropdown_change(attr, oldvalue, newvalue, parent=self):
            # bokeh on_change callbacks receive (attr, old, new); there is no
            # ipywidgets style `change` dict here
            print(
                f"on_dropdown_change: {attr}, {oldvalue}, {newvalue}, {parent}"
            )
            if newvalue != oldvalue:
                parent.update_montage(newvalue)
                parent.update_plot_after_montage_change()
                parent.update()

        self.ui_montage_dropdown.on_change("value", on_dropdown_change)

        # shared layout for the narrow ipywidgets controls below
        flayout = ipywidgets.Layout()
        flayout.width = "12em"

        self.ui_low_freq_filter_dropdown = ipywidgets.Dropdown(
            options=list(self._highpass_cache.keys()),
            value="0.3 Hz",
            description="LF",
            layout=flayout,
        )

        def lf_dropdown_on_change(change, parent=self):
            if change["name"] == "value":  # the value changed
                if change["new"] != change["old"]:
                    parent.current_hp_filter = parent._highpass_cache[
                        change["new"]]
                    parent.update()

        self.ui_low_freq_filter_dropdown.observe(lf_dropdown_on_change)

        self.ui_high_freq_filter_dropdown = ipywidgets.Dropdown(
            options=list(self._lowpass_cache.keys()),
            description="HF",
            layout=flayout,
        )

        def hf_dropdown_on_change(change, parent=self):
            if change["name"] == "value":  # the value changed
                if change["new"] != change["old"]:
                    self.current_lp_filter = self._lowpass_cache[change["new"]]
                    self.update()

        self.ui_high_freq_filter_dropdown.observe(hf_dropdown_on_change)

        # defined but not attached to any widget in this method
        def go_to_handler(change, parent=self):
            if change["name"] == "value":
                self.loc_sec = change["new"]
                self.update()

        self.ui_notch_option = ipywidgets.Checkbox(value=False,
                                                   description="60Hz Notch",
                                                   disabled=False)

        def notch_change(change):
            if change["name"] == "value":
                if change["new"]:
                    self.current_notch_filter = self._notch_filter
                else:
                    self.current_notch_filter = None
                self.update()

        self.ui_notch_option.observe(notch_change)

        self.ui_gain_bounded_float = ipywidgets.BoundedFloatText(
            value=1.0,
            min=0.001,
            max=1000.0,
            step=0.1,
            description="gain",
            disabled=False,
            continuous_update=False,  # only trigger when done
            layout=flayout,
        )

        def ui_gain_on_change(change, parent=self):
            if change["name"] == "value":
                if change["new"] != change["old"]:
                    self.yscale = float(change["new"])
                    self.update()

        self.ui_gain_bounded_float.observe(ui_gain_on_change)
        top_bar_layout = bokeh.layouts.row(
            self.ui_montage_dropdown,
            self.ui_low_freq_filter_dropdown,
            self.ui_high_freq_filter_dropdown,
            self.ui_notch_option,
            self.ui_gain_bounded_float,
        )
        return top_bar_layout
        # display(
        #     ipywidgets.HBox(
        #         [
        #             self.ui_montage_dropdown,
        #             self.ui_low_freq_filter_dropdown,
        #             self.ui_high_freq_filter_dropdown,
        #             self.ui_notch_option,
        #             self.ui_gain_bounded_float,
        #         ]
        #     )
        # )

    def register_top_bar_ui(self):
        """Create the top control-bar widgets and attach their callbacks.

        Builds, in this order, the montage ``Select``, the "LF" and "HF"
        filter ``Select`` widgets, the "60Hz Notch" ``CheckboxGroup`` and the
        gain ``Spinner``.  Every widget is stored on ``self``; its callback
        updates the matching viewer attribute and then calls
        ``self.update()``.

        Returns:
            The ``bokeh.layouts.row`` holding the widgets, which is also
            stored as ``self.top_bar_layout``.
        """

        # -- montage selector -------------------------------------------
        def on_dropdown_change(attr, oldvalue, newvalue):
            self.update_montage(newvalue)
            self.update_plot_after_montage_change()
            self.update()

        self.ui_montage_dropdown = Select(
            title="Montage:",
            value=str(self.current_montage_instance.name),
            options=list(self.montage_options.keys()),
        )
        self.ui_montage_dropdown.on_change("value", on_dropdown_change)

        # -- "LF" selector: picks the entry of the high-pass cache ---------
        def lf_dropdown_on_change(attr, oldvalue, newvalue):
            self.current_hp_filter = self._highpass_cache[newvalue]
            self.update()

        self.ui_low_freq_filter_dropdown = Select(
            title="LF",
            value="0.3 Hz",
            options=list(self._highpass_cache.keys()),
            max_width=150,
        )
        self.ui_low_freq_filter_dropdown.on_change("value",
                                                   lf_dropdown_on_change)

        # -- "HF" selector: picks the entry of the low-pass cache ----------
        def hf_dropdown_on_change(attr, oldvalue, newvalue):
            self.current_lp_filter = self._lowpass_cache[newvalue]
            self.update()

        self.ui_high_freq_filter_dropdown = Select(
            title="HF",
            value="None",
            options=list(self._lowpass_cache.keys()),
            max_width=150,
        )
        self.ui_high_freq_filter_dropdown.on_change("value",
                                                    hf_dropdown_on_change)

        # -- notch checkbox ------------------------------------------------
        def notch_change(newvalue):
            # newvalue is the list of active checkbox indices; any other
            # value leaves the current notch filter untouched.
            if newvalue == [0]:
                self.current_notch_filter = self._notch_filter
            elif newvalue == []:
                self.current_notch_filter = None
            self.update()

        self.ui_notch_option = CheckboxGroup(labels=["60Hz Notch"])
        self.ui_notch_option.on_click(notch_change)

        # -- gain spinner --------------------------------------------------
        def ui_gain_on_change(attr, oldvalue, newvalue):
            self.yscale = float(newvalue)
            self.update()

        self.ui_gain_bounded_float = Spinner(
            title="gain",
            value=1.0,
            step=0.1,
            width=100,
        )
        self.ui_gain_bounded_float.on_change("value", ui_gain_on_change)

        # Note: the gain spinner sits before the notch checkbox in the row.
        widgets = [
            self.ui_montage_dropdown,
            self.ui_low_freq_filter_dropdown,
            self.ui_high_freq_filter_dropdown,
            self.ui_gain_bounded_float,
            self.ui_notch_option,
        ]
        self.top_bar_layout = bokeh.layouts.row(*widgets)
        return self.top_bar_layout

    def _limit_time_check(self, candidate):
        """Clamp *candidate* (seconds) to the recording's time range.

        Returns ``float(duration_seconds)`` when *candidate* is past the end
        of the file, ``0.0`` when it is negative, and *candidate* itself
        (unchanged, not converted) otherwise.
        """
        duration = self.eeghdf_file.duration_seconds
        if candidate > duration:
            return float(duration)
        return 0.0 if candidate < 0 else candidate

    def register_bottom_bar_ui(self):
        """Create the time-navigation buttons and return them in a row.

        Four buttons are stored on ``self`` (forward/backward by 10 s and by
        1 s).  Clicking one shifts ``self.loc_sec`` by the matching amount,
        clamps it with ``self._limit_time_check`` and calls ``self.update()``.

        Returns:
            The ``bokeh.layouts.row`` of buttons, also stored as
            ``self.ui_bottom_bar_layout``.
        """

        def make_stepper(delta_sec):
            """Return a click handler that moves the view by *delta_sec*."""

            def step(event):
                self.loc_sec = self._limit_time_check(self.loc_sec + delta_sec)
                self.update()

            return step

        self.ui_buttonf = Button(label="go forward 10s")
        self.ui_buttonback = Button(label="go backward 10s")
        self.ui_buttonf1 = Button(label="forward 1 s")
        self.ui_buttonback1 = Button(label="back 1 s")

        steps = (
            (self.ui_buttonf, 10),
            (self.ui_buttonback, -10),
            (self.ui_buttonf1, 1),
            (self.ui_buttonback1, -1),
        )
        for button, delta_sec in steps:
            button.on_click(make_stepper(delta_sec))

        # Handler for a future "go to time" input; it is defined here but
        # not attached to any widget in this method.
        def go_to_handler(attr, oldvalue, newvalue, parent=self):
            self.loc_sec = self._limit_time_check(float(newvalue))
            self.update()

        self.ui_bottom_bar_layout = bokeh.layouts.row(
            self.ui_buttonback,
            self.ui_buttonf,
            self.ui_buttonback1,
            self.ui_buttonf1,
        )
        return self.ui_bottom_bar_layout

    def update_montage(self, montage_name):
        """Make the montage registered under *montage_name* the current one.

        Looks the montage factory up in ``self.montage_options``, builds it
        from ``self.ref_labels``, resets the displayed channel window to
        ``[0, montage.shape[0])`` and refreshes the title.
        """
        montage_factory = self.montage_options[montage_name]
        montage = montage_factory(self.ref_labels)

        self.current_montage_instance = montage
        self.ch_start = 0
        self.ch_stop = montage.shape[0]

        self.update_title()
Exemplo n.º 16
0
def create(palm):
    """Build the "Setup" tab for *palm*.

    The tab contains two text inputs bound to ``palm.channels["0"]`` and
    ``palm.channels["1"]`` and three spinners bound to ``palm.xfel_energy``,
    ``palm.binding_energy`` and ``palm.zero_drift_tube``.  A spinner only
    accepts strictly positive values; anything else (including ``None``,
    which a cleared input field may report) is rejected and the widget is
    reset to its previous value.

    Returns:
        A ``Panel`` titled "Setup" wrapping the column of widgets.
    """

    def channel_textinput(title, channel_key):
        """Create a text input that writes into ``palm.channels[channel_key]``."""

        def on_value(_attr, _old_value, new_value):
            palm.channels[channel_key] = new_value

        text_input = TextInput(title=title, value=palm.channels[channel_key])
        text_input.on_change("value", on_value)
        return text_input

    def positive_spinner(title, attr_name):
        """Create a spinner that writes positive values to ``palm.<attr_name>``."""
        spinner = Spinner(title=title, value=getattr(palm, attr_name), step=0.1)

        def on_value(_attr, old_value, new_value):
            # Guard against None before comparing: `None > 0` raises
            # TypeError and would leave the widget showing a rejected value.
            if new_value is not None and new_value > 0:
                setattr(palm, attr_name, new_value)
            else:
                spinner.value = old_value

        spinner.on_change("value", on_value)
        return spinner

    ref_etof_channel_textinput = channel_textinput(
        "Reference eTOF channel:", "0")
    str_etof_channel_textinput = channel_textinput(
        "Streaking eTOF channel:", "1")

    xfel_energy_spinner = positive_spinner("XFEL energy, eV:", "xfel_energy")
    binding_energy_spinner = positive_spinner(
        "Binding energy, eV:", "binding_energy")
    zero_drift_spinner = positive_spinner(
        "Zero drift tube, eV:", "zero_drift_tube")

    tab_layout = column(
        ref_etof_channel_textinput,
        str_etof_channel_textinput,
        Spacer(height=30),
        xfel_energy_spinner,
        binding_energy_spinner,
        zero_drift_spinner,
    )

    return Panel(child=tab_layout, title="Setup")
Exemplo n.º 17
0
# Minimum win/tie probability filter.
win_prob_slider = Slider(
    title="Min. Win/Tie Probability",
    start=0.7,
    end=1,
    step=0.01,
    value=win_prob,
    width=200,
)
win_prob_slider.on_change('value', update_win_prob)

# Minimum odds filter.
odds_spinner = Spinner(
    title="Min. Odds",
    low=1.01,
    high=1000,
    step=0.01,
    value=1.2,
    width=200,
)
odds_spinner.on_change('value', update_odds)

# Bet size in dollars.
bet_spinner = Spinner(
    title="Bet ($)",
    low=5,
    high=100000,
    step=5,
    value=5,
    width=200,
)
bet_spinner.on_change('value', update_spinner)

# Date window: spans the whole data set, initially showing the last 365 days.
_first_date = odf['date'].min().date()
_last_date = odf['date'].max().date()
date_range_slider = DateRangeSlider(
    start=_first_date,
    end=_last_date,
    value=(_last_date - datetime.timedelta(days=365), _last_date),
)
date_range_slider.on_change('value', update_year)
Exemplo n.º 18
0
def bkapp(doc):
    
### Functions ###

    # functions for user dialogs

    def open_file(ftype):
        """Show a native "Open File" dialog and return its result.

        *ftype* is passed through as the dialog's ``filetypes`` filter and
        the dialog starts in the current working directory.  The hidden Tk
        root window is destroyed even if the dialog raises, so no stray
        window is left behind.
        """
        root = Tk()
        try:
            root.withdraw()
            return askopenfilename(filetypes=ftype,
                                   title='Open File',
                                   initialdir=os.getcwd())
        finally:
            root.destroy()
    
    def choose_directory():
        """Show a native directory chooser and return its result.

        The hidden Tk root window is destroyed even if the dialog raises.
        """
        root = Tk()
        try:
            root.withdraw()
            return askdirectory()
        finally:
            root.destroy()

    def write_output_directory(output_dir):
        """Ask whether *output_dir* should be created and return the answer.

        Returns whatever ``askquestion`` returns; the caller compares the
        result against ``'yes'``.  The hidden Tk root window is destroyed
        even if the dialog raises.
        """
        root = Tk()
        try:
            root.withdraw()
            return askquestion('Make Directory','Output directory not set. Make directory: '
                +output_dir+'? If not, you\'ll be prompted to change directories.',icon = 'warning')
        finally:
            root.destroy()

    def overwrite_file():
        """Ask whether an existing file may be overwritten; return the answer.

        Returns whatever ``askquestion`` returns; callers compare the result
        against ``'yes'`` / ``'no'``.  The hidden Tk root window is destroyed
        even if the dialog raises.
        """
        root = Tk()
        try:
            root.withdraw()
            # Prompt text fixed: "exits" -> "exists".
            return askquestion('Overwrite File','File already exists. Do you want to overwrite?',icon = 'warning')
        finally:
            root.destroy()

    def update_filename():
        """Let the user pick an .mp4 file and load it if one was chosen."""
        chosen = open_file([("Video files", "*.mp4")])
        if not chosen:
            return
        load_data(filename=chosen)

    def change_directory():
        """Let the user pick an output directory and record the choice.

        When a directory is chosen it is stored in ``source.data`` and shown
        in the ``outDir`` widget.  The dialog result is returned either way,
        so callers can test it for truthiness.
        """
        chosen_dir = choose_directory()
        if not chosen_dir:
            return chosen_dir
        source.data["output_dir"] = [chosen_dir]
        outDir.text = chosen_dir
        return chosen_dir

    # load data from file

    def load_data(filename):
        """Load the first frame of the video *filename* and display it.

        The first channel of the first frame is flipped vertically and
        stored in ``source`` as both the original and the displayed image.
        The output directory defaults to a folder next to the video named
        after the file (text before the first dot) unless one is already
        set in ``source``; if it does not exist the user is asked whether to
        create it or to pick another one.  Nothing is changed when the video
        cannot be read, and ``source`` is left untouched when no output
        directory ends up being selected.
        """
        vidcap = cv2.VideoCapture(filename)
        try:
            success,frame = vidcap.read()
        finally:
            # Only the first frame is needed; free the file handle right away.
            vidcap.release()
        if not success:
            # Without this guard cv2.split(None) would raise further down.
            print('Could not read video file! Check the file and try again.{:<100}'.format(' '),end="\r")
            return
        img_tmp,_,__ = cv2.split(frame)
        h,w = np.shape(img_tmp)
        img = np.flipud(img_tmp)
        radio_button_gp.active = 0
        input_dir,fname = os.path.split(filename)
        if source.data['output_dir'][0]=='':
            output_dir = os.path.join(input_dir,fname.split('.')[0])
        else:
            output_dir = source.data['output_dir'][0]
        if not os.path.isdir(output_dir):
            makeDir = write_output_directory(output_dir)
            if makeDir=='yes':
                os.mkdir(output_dir)
            else:
                # May come back falsy if the user cancels the chooser.
                output_dir = change_directory()
        if not output_dir:
            print('Cancelled. To continue please set output directory.{:<100}'.format(' '),end="\r")
            return
        source.data = dict(image_orig=[img], image=[img], bin_img=[0],
            x=[0], y=[0], dw=[w], dh=[h], num_contours=[0], roi_coords=[0],
            img_name=[fname], input_dir=[input_dir], output_dir=[output_dir])
        # Replace any previously drawn image renderer instead of stacking.
        curr_img = p.select_one({'name':'image'})
        if curr_img:
            p.renderers.remove(curr_img)
        p.image(source=source, image='image', x='x', y='y', dw='dw', dh='dh', color_mapper=cmap,level='image',name='image')
        # Plot is shown at half the video's pixel size.
        p.plot_height=int(h/2)
        p.plot_width=int(w/2)
        inFile.text = fname
        outDir.text = output_dir

    # resetting sources for new data or new filters/contours

    def remove_plot():
        """Reset the contour count, status text, contour lines and labels."""
        source.data["num_contours"] = [0]
        contours_found.text = 'Droplets Detected: 0'
        source_contours.data = {"xs": [], "ys": []}
        source_label.data = {"x": [], "y": [], "label": []}

    # apply threshold filter and display binary image

    def apply_filter():
        """Threshold the original image with the selected filter and show it.

        The radio-button index selects the method: 1-5 are global
        thresholds (Otsu, ISODATA, mean, Li, Yen) and 6 is a local threshold
        using the block-size and offset spinners.  Any other selection shows
        the original image again and leaves the stored binary image as is.
        Existing contours are cleared first.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return

        img = np.squeeze(source.data['image_orig'])
        # remove contours if present
        if source.data["num_contours"]!=[0]:
            remove_plot()

        global_thresholds = {
            1: filters.threshold_otsu,
            2: filters.threshold_isodata,
            3: filters.threshold_mean,
            4: filters.threshold_li,
            5: filters.threshold_yen,
        }
        choice = radio_button_gp.active
        if choice in global_thresholds:
            thresh = global_thresholds[choice](img)
        elif choice == 6:
            off = offset_spinner.value
            block_size = block_spinner.value
            thresh = filters.threshold_local(img,block_size,offset=off)
        else:
            # No threshold selected: display the unfiltered image.
            source.data['image'] = [img]
            return

        bin_img = (img > thresh)*255
        source.data["bin_img"] = [bin_img]
        source.data['image'] = [bin_img]

    # image functions for adjusting the binary image
    
    def close_img():
        """Apply a binary closing to the inverted binary image and show it.

        Existing contours are cleared first.  Does nothing but print a hint
        when no threshold filter is selected.
        """
        if source.data["num_contours"]!=[0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        source.data["image"] = source.data["bin_img"]
        binary = np.squeeze(source.data['bin_img'])
        # Operate on the inverted image, then invert the result back.
        closed = binary_closing(255-binary)*255
        source.data['image'] = [255-closed]
        source.data['bin_img'] = [255-closed]

    def dilate_img():
        """Apply a binary dilation to the inverted binary image and show it.

        Existing contours are cleared first.  Does nothing but print a hint
        when no threshold filter is selected.
        """
        if source.data["num_contours"]!=[0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        binary = np.squeeze(source.data['bin_img'])
        # Operate on the inverted image, then invert the result back.
        dilated = binary_dilation(255-binary)*255
        source.data['image'] = [255-dilated]
        source.data['bin_img'] = [255-dilated]

    def erode_img():
        """Apply a binary erosion to the inverted binary image and show it.

        Existing contours are cleared first.  Does nothing but print a hint
        when no threshold filter is selected.
        """
        if source.data["num_contours"]!=[0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        binary = np.squeeze(source.data['bin_img'])
        # Operate on the inverted image, then invert the result back.
        eroded = binary_erosion(255-binary)*255
        source.data['image'] = [255-eroded]
        source.data['bin_img'] = [255-eroded]

    # the function for identifying closed contours in the image

    def find_contours(level=0.8):
        """Detect droplet contours in the binary image and overlay them.

        Contours are traced at *level* and kept only if their number of
        points lies strictly between the two ``contour_rng_slider`` values
        and both their width and height lie strictly between 20 and 200.
        For every kept contour an ROI entry ``[width, height, x_min,
        image_height - y_max]`` (rounded) and a 1-based text label are
        recorded.  The original image is shown again with the contours drawn
        on top, and the droplet count is written to ``contours_found``.

        Prints a hint and returns without changes when no threshold filter
        is selected or no image is loaded.
        """
        min_drop_size = contour_rng_slider.value[0]
        max_drop_size = contour_rng_slider.value[1]
        min_dim = 20
        max_dim = 200
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return

        img = np.squeeze(source.data['bin_img'])
        h,_ = np.shape(img)

        matched_cnt_x = []
        matched_cnt_y = []
        roi_coords = []
        label_text = []
        label_y = []
        for cnt in measure.find_contours(img, level):
            # First filter on the number of points in the contour ...
            n_points = np.shape(cnt)[0]
            if not (n_points < max_drop_size and n_points > min_drop_size):
                continue
            cnt_x = cnt[:,1]
            cnt_y = cnt[:,0]
            # ... then on the size of its bounding box.
            width = np.max(cnt_x)-np.min(cnt_x)
            height = np.max(cnt_y)-np.min(cnt_y)
            if width>min_dim and width<max_dim and height>min_dim and height<max_dim:
                matched_cnt_x.append(cnt_x)
                matched_cnt_y.append(cnt_y)
                roi_coords.append([round(width),round(height),round(np.min(cnt_x)),round(h-np.max(cnt_y))])
                label_text.append(str(len(roi_coords)))
                label_y.append(np.max(cnt_y))
        count = len(roi_coords)

        # Drop renderers left over from a previous run.  The line renderer
        # added below is named "contours", so it must be removed here too;
        # otherwise one more renderer is stacked on every call.
        for stale_name in ('overlay', 'contours'):
            stale = p.select_one({'name':stale_name})
            if stale:
                p.renderers.remove(stale)

        source.data["image"] = source.data["image_orig"]
        source.data["num_contours"] = [count]
        source.data["roi_coords"] = [roi_coords]
        source_contours.data = dict(xs=matched_cnt_x, ys=matched_cnt_y)
        p.multi_line(xs='xs',ys='ys',source=source_contours, color=(255,127,14),line_width=2, name="contours",level='glyph')

        if not roi_coords:
            print('No contours found. Try adjusting parameters or filter for thresholding.{:<100}'.format(' '),end="\r")
            source_label.data = dict(x=[],y=[],label=[])
        else:
            # Labels sit at each ROI's left edge (column 2) and top (y max).
            source_label.data = dict(x=np.array(roi_coords)[:,2],
                                     y=np.array(label_y, dtype=float),
                                     label=label_text)
        contours_found.text = 'Droplets Detected: '+str(len(matched_cnt_x))

    # write the contours and parameters to files

    def export_ROIs():
        """Write the contour parameters and ROI coordinates to CSV files.

        ``ContourParams.csv`` receives the threshold filter index and the
        contour size range (plus the local offset and block size when the
        local filter, index 6, is active); ``ROI_coords.csv`` receives one
        ``width,height,x,y`` row per ROI.  Both go to the current output
        directory.  The user is asked before an existing parameter file is
        overwritten.  Prints a hint and does nothing when no contours have
        been found.
        """
        if source.data["num_contours"]==[0]:
            print("No Contours Found! Find contours first.{:<100}".format(' '),end="\r")
            return

        hdr = 'threshold filter,contour min,contour max'
        thresh_filter = radio_button_gp.active
        cnt_min = contour_rng_slider.value[0]
        cnt_max = contour_rng_slider.value[1]
        params = [thresh_filter,cnt_min,cnt_max]
        if thresh_filter == 6:
            off = offset_spinner.value
            block_size = block_spinner.value
            hdr = hdr + ',local offset,local block size'
            # list.append() takes exactly one argument; extend() adds both.
            params.extend([off,block_size])

        params_fname = 'ContourParams.csv'
        params_out = os.path.join(source.data['output_dir'][0],params_fname)
        overwrite = 'no'
        if os.path.exists(params_out):
            overwrite = overwrite_file()
        if overwrite == 'yes' or not os.path.exists(params_out):
            np.savetxt(params_out,np.array([params]),delimiter=',',header=hdr,comments='')

        roi_coords = np.array(source.data["roi_coords"][0])
        out_fname = 'ROI_coords.csv'
        out_fullpath = os.path.join(source.data['output_dir'][0],out_fname)
        # NOTE(review): the answer given for the parameter file is reused
        # here, so an existing ROI file is only replaced when the user
        # agreed to overwrite the parameter file — confirm this is intended.
        if overwrite == 'yes' or not os.path.exists(out_fullpath):
            hdr = 'width,height,x,y'
            np.savetxt(out_fullpath,roi_coords,delimiter=',',header=hdr,comments='')
            print('Successfully saved ROIs coordinates as: '+out_fullpath,end='\r')
            source.data['roi_coords'] = [roi_coords]

    # function for loading previously made files or error handling for going out of order

    def check_ROI_files():
        """Reload ROIs and contour parameters from the output directory.

        When ``ROI_coords.csv`` exists, its rows replace the ROIs held in
        ``source`` (a notice is printed if they differ from the current
        ones), the widgets are set from ``ContourParams.csv`` and the
        contours are recomputed.  Otherwise a hint is printed.
        """
        out_dir = source.data["output_dir"][0]
        coords_file = os.path.join(out_dir,'ROI_coords.csv')
        n_cnt_curr = source.data["num_contours"][0]
        roi_coords_curr = source.data["roi_coords"][0]
        if not os.path.exists(coords_file):
            print("ROI files not found! Check save directory or export ROIs.{:<100}".format(' '),end="\r")
            return

        roi_coords = np.array(pd.read_csv(coords_file, sep=',').values)
        n_cnt = len(roi_coords)
        same_rois = n_cnt_curr == n_cnt and np.array_equal(roi_coords_curr,roi_coords)
        if not same_rois:
            print('Current ROIs are different from saved ROI file! ROIs from saved file will be used instead and plot updated.',end="\r")
        source.data["num_contours"] = [n_cnt]
        source.data["roi_coords"] = [roi_coords]

        # Restore the widget state the saved ROIs were produced with.
        params_df = pd.read_csv(os.path.join(out_dir,'ContourParams.csv'))
        thresh_ind = params_df["threshold filter"].values[0]
        radio_button_gp.active = int(thresh_ind)
        if thresh_ind == 6:
            offset_spinner.value = int(params_df["local offset"].values[0])
            block_spinner.value = int(params_df["local block size"].values[0])
        contour_min = int(params_df["contour min"].values[0])
        contour_max = int(params_df["contour max"].values[0])
        contour_rng_slider.value = (contour_min, contour_max)
        find_contours()

    # use FFMPEG to crop out regions from original mp4 and save to file

    def create_ROI_movies():
        """Crop one ``ROI<n>.mp4`` per ROI out of the loaded video with ffmpeg.

        ROIs are first reloaded from the saved files.  Each movie is a
        ``side`` x ``side`` pixel crop whose origin is the ROI's stored x/y
        minus ``padding``, written to the output directory.  The user is
        asked before an existing movie is overwritten.  A message is printed
        when ffmpeg exits with a non-zero status.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return
        check_ROI_files()
        side = 100 # crop width and height in pixels (square ROIs)
        padding = 20
        roi_coords_file = os.path.join(source.data['output_dir'][0],'ROI_coords.csv')
        if source.data["num_contours"]==[0]:
            print("No contours found! Find contours first.{:<100}".format(' '),end="\r")
            return
        if not os.path.exists(roi_coords_file):
            print("ROI file does not exist! Check save directory or export ROIs.{:<100}".format(' '),end="\r")
            return

        print('Creating Movies...{:<100}'.format(' '),end="\r")
        n_cnt = source.data["num_contours"][0]
        roi_coords = np.array(source.data["roi_coords"][0])
        inPath = os.path.join(source.data['input_dir'][0],source.data['img_name'][0])
        pbar = tqdm(total=n_cnt)
        for i in range(n_cnt):
            out_fname = 'ROI'+str(i+1)+'.mp4'
            outPath = os.path.join(source.data['output_dir'][0],out_fname)
            crop = f"crop={side}:{side}:{(roi_coords[i,2]-padding)}:{(roi_coords[i,3]-padding)}"
            # Argument list without a shell: paths containing quotes or
            # other shell metacharacters are passed to ffmpeg verbatim.
            command = ["ffmpeg", "-i", inPath, "-vf", crop, "-y", outPath]
            overwrite = 'no'
            if os.path.exists(outPath):
                overwrite = overwrite_file()
            if overwrite == 'yes' or not os.path.exists(outPath):
                # call() returns the exit status; check_call() would raise
                # before the message below could ever be printed.
                saved = subprocess.call(command)
                if saved != 0:
                    print('An error occurred while creating movies! Check terminal window.{:<100}'.format(' '),end="\r")
            pbar.update()
        pbar.close()

    # change the display range on images from slider values

    def update_image():
        """Apply the display-range slider's two values to the color mapper."""
        low, high = display_range_slider.value
        cmap.low = low
        cmap.high = high

    # create statistics files for each mp4 region specific file

    def process_ROIs():
        """Compute per-frame statistics for every ROI movie and save them.

        For each ROI index ``i`` the movie ``ROI<i+1>.mp4`` in the output
        directory is read frame by frame; 13 values per frame (see ``hdr``)
        are written to ``stats_ROI<i+1>.csv`` and, after the last ROI, the
        accumulated rows are grouped by ROI into ``sources_stats``.  The user
        is asked before an existing stats file is overwritten; ROIs whose
        file is kept are not added to the accumulated rows.  Processing stops
        at the first missing ROI movie.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
        else:
            # Sync ROIs/parameters with the files saved in the output directory.
            check_ROI_files()
            hdr = 'roi,time,area,mean,variance,min,max,median,skewness,kurtosis,rawDensity,COMx,COMy'
            cols = hdr.split(',')
            # Rows of all (re)computed ROIs; starts empty with 13 columns.
            all_stats = np.zeros((0,13))
            n_cnt = source.data["num_contours"][0]
            if n_cnt == 0:
                # The loop below is skipped when n_cnt is 0.
                print("No contours found! Find contours first.{:<100}".format(' '),end="\r")
            for i in range(n_cnt): 
                #in_fname = source.data['filename'][0].split('.')[0] +'_ROI'+str(i+1)+'.mp4'
                in_fname = 'ROI'+str(i+1)+'.mp4'
                inPath = os.path.join(source.data['output_dir'][0],in_fname)
                #out_fname = source.data['filename'][0].split('.')[0] +'_ROI'+str(i+1)+'_stats.csv'
                out_fname = 'stats_ROI'+str(i+1)+'.csv'
                outPath = os.path.join(source.data['output_dir'][0],out_fname)
                if not os.path.exists(inPath):
                    print('ROI movie not found! Create ROI movie first.{:<100}'.format(' '),end="\r")
                    break
                vidcap = cv2.VideoCapture(inPath)
                last_frame = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
                if i==0:
                    # The bar is sized from the first movie's frame count times
                    # the number of ROIs; it is only created on the first pass.
                    pbar = tqdm(total=last_frame*n_cnt)
                # Frame 0 is handled here; frames 1.. are handled in the loop below.
                success,frame = vidcap.read()
                # Only the first of the three channels is used.
                img_tmp,_,__ = cv2.split(frame)
                h,w = np.shape(img_tmp)
                # NOTE(review): img is filled below but never read in this function.
                img = np.zeros((last_frame,h,w))
                img_stats = np.zeros((last_frame,13))
                stats = describe(img_tmp,axis=None)
                median = np.median(img_tmp)
                density = np.sum(img_tmp)
                cx, cy = center_of_mass(img_tmp)
                # Column order matches hdr; time is 0 for the first frame.
                img_stats[0,0:13] = [i,0,stats.nobs,stats.mean,stats.variance,
                        stats.minmax[0],stats.minmax[1],median,stats.skewness,
                        stats.kurtosis,density,cx,cy]
                img[0,:,:] = np.flipud(img_tmp)
                pbar.update()
                overwrite = 'no'
                if os.path.exists(outPath):
                    overwrite = overwrite_file()
                    if overwrite=='no':
                        # Skipping this ROI: advance the bar past its remaining frames.
                        pbar.update(last_frame-1)
                if overwrite == 'yes' or not os.path.exists(outPath):
                    for j in range(1,last_frame):
                        # Property id 1 is the frame position (cv2.CAP_PROP_POS_FRAMES).
                        vidcap.set(1, j)
                        success,frame = vidcap.read()
                        img_tmp,_,__ = cv2.split(frame)
                        stats = describe(img_tmp,axis=None)
                        # presumably 5 minutes between frames, expressed in hours — TODO confirm
                        t = j*5/60
                        density = np.sum(img_tmp)
                        cx, cy = center_of_mass(img_tmp)
                        median = np.median(img_tmp)
                        img_stats[j,0:13] = [i,t,stats.nobs,stats.mean,stats.variance,
                                stats.minmax[0],stats.minmax[1],median,stats.skewness,
                                stats.kurtosis,density,cx,cy]
                        img[j,:,:] = np.flipud(img_tmp)
                        pbar.update(1)
                    all_stats = np.append(all_stats,img_stats,axis=0)
                    np.savetxt(outPath,img_stats,delimiter=',',header=hdr,comments='')
                if i==(n_cnt-1):
                    # After the last ROI: publish the accumulated rows per ROI.
                    df = pd.DataFrame(all_stats,columns=cols)
                    group = df.groupby('roi')
                    # NOTE(review): this inner loop reuses the name i, and
                    # get_group(i) assumes the group keys are 0..len(group)-1;
                    # that does not hold if an earlier ROI was skipped above — confirm.
                    for i in range(len(group)):
                        sources_stats[i] = ColumnDataSource(group.get_group(i))

    # load statistics CSVs and first ROI mp4 files and display in plots

    def load_ROI_files():
        """Load every ROI's statistics CSV and add its curves to the plots.

        Requires one ``stats_ROI<n>.csv`` per ROI in the output directory.
        Each ROI gets a ``ColumnDataSource`` in ``sources_stats`` and, for
        every data column from the fourth one on, an initially hidden line
        named ``'ROI <n>'`` in the matching ``p_stats`` figure (skipped when
        a line of that name already exists).  Finally the ROI selection
        widgets are filled and reset to the first ROI.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return
        check_ROI_files()
        n_cnt = source.data["num_contours"][0]
        basepath = os.path.join(source.data["output_dir"][0],'stats')
        all_files = [basepath+'_ROI'+str(k+1)+'.csv' for k in range(n_cnt)]
        if n_cnt == 0 or not all(os.path.exists(f) for f in all_files):
            print('Not enough files! Check save directory or calculate new stats.{:<100}'.format(' '),end="\r")
            return

        df = pd.concat((pd.read_csv(f) for f in all_files), ignore_index=True)
        group = df.groupby('roi')
        pbar = tqdm(total=len(stats)*len(group))

        # Use the ten even-indexed palette entries first, then the odd ones.
        palette = list(Category20[20])
        colors = palette[0::2] + palette[1::2]
        # The first three columns (roi, time, area) are not plotted.
        stat_columns = [str(col) for col in df.columns[3:]]

        OPTIONS = []
        LABELS = []
        for n, (roi, df_roi) in enumerate(group):
            roi_number = str(int(roi)+1)
            name = 'ROI '+roi_number
            sources_stats[roi] = ColumnDataSource(df_roi)
            OPTIONS.append([roi_number,name])
            LABELS.append(name)
            color = colors[n % 20]  # cycle through the 20 colors
            for k, column in enumerate(stat_columns):
                fig = p_stats[k]
                if not fig.select_one({'name':name}):
                    fig.line(x='time',y=column,source=sources_stats[roi],
                        name=name,visible=False,line_color=color)
                    fig.xaxis.axis_label = "Time [h]"
                    fig.yaxis.axis_label = column
                    fig.add_tools(HoverTool(tooltips=TOOLTIPS))
                    fig.toolbar_location = "left"
                pbar.update(1)

        ROI_multi_select.options = OPTIONS
        ROI_multi_select.value = ["1"]
        ROI_movie_radio_group.labels = LABELS
        ROI_movie_radio_group.active = 0

    # show/hide curves from selected/deselected labels for ROIs in statistics plots

    def update_ROI_plots():
        """Show the curves of the selected ROIs and hide all others.

        A curve named ``'ROI <n>'`` is visible exactly when ``'<n>'`` is in
        the multi-select's current value.
        """
        n_cnt = source.data["num_contours"][0]
        pbar = tqdm(total=len(stats)*n_cnt)
        for roi_idx in range(n_cnt):
            roi_number = str(roi_idx+1)
            name = 'ROI '+roi_number
            for stat_idx in range(len(stats)):
                glyph = p_stats[stat_idx].select_one({'name': name})
                glyph.visible = roi_number in ROI_multi_select.value
                pbar.update(1)

    # load and display the selected ROI's mp4

    def load_ROI_movie():
        """Load the selected ROI's mp4 and show its first frame in ``p_ROI``.

        The movie ``ROI<k>.mp4`` (``k`` = active radio index + 1) is read from
        the current output directory.  On success the previously displayed
        image renderer is removed, ``sourceROI`` is refilled with the first
        frame and the time slider is reset to the length of the new movie.
        When nothing is selected, the file is missing, or its first frame
        cannot be decoded, the current display is left untouched.
        """
        idx = ROI_movie_radio_group.active
        if idx is None:
            # no radio entry is active, so there is no movie to load
            return
        name = 'ROI'+str(idx+1)
        inPath = os.path.join(source.data['output_dir'][0],name+'.mp4')
        if not os.path.exists(inPath):
            print('ROI movie not found! Check save directory or create ROI movie.',end="\r")
            return

        vidcap = cv2.VideoCapture(inPath)
        try:
            last_frame = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            success,frame = vidcap.read()
        finally:
            # always release the capture handle, also when decoding fails
            vidcap.release()
        if not success:
            print('ROI movie could not be read! Check the mp4 file.',end="\r")
            return

        # drop the renderer of the previously shown movie, if there is one
        old_plot = p_ROI.select_one({'name': sourceROI.data['img_name'][0]})
        if old_plot:
            p_ROI.renderers.remove(old_plot)

        # only the first channel of the frame is displayed
        img_tmp,_,__ = cv2.split(frame)
        h,w = np.shape(img_tmp)
        img = np.flipud(img_tmp)
        sourceROI.data = dict(image=[img],x=[0], y=[0], dw=[w], dh=[h],
            img_name=[name])
        # The renderer must carry the ROI name itself (not the column name
        # 'img_name'); otherwise the select_one() lookup above never finds it
        # on the next load and image renderers accumulate in the figure.
        p_ROI.image(source=sourceROI, image='image', x='x', y='y',
           dw='dw', dh='dh', color_mapper=cmap, name=name)

        # Reset the slider last, so that a value-change callback triggered by
        # it already sees the newly loaded movie in sourceROI.
        ROI_movie_slider.end = (last_frame-1)*5/60
        ROI_movie_slider.value = 0

    # change the displayed frame from slider movement

    def update_ROI_movie():
        """Display the movie frame that matches the current slider position.

        The slider value is converted to a frame index with the same 5/60
        factor used for the slider step, the frame is read from the mp4 named
        after ``sourceROI.data['img_name'][0]`` and its first channel replaces
        the displayed image.  Nothing happens when no movie has been loaded
        yet or when the requested frame cannot be read.
        """
        frame_idx = round(ROI_movie_slider.value*60/5)
        img_name = sourceROI.data['img_name'][0]
        if not isinstance(img_name, str) or not img_name:
            # sourceROI still holds its placeholder entry: no movie loaded
            return
        inPath = os.path.join(source.data['output_dir'][0],img_name+'.mp4')
        vidcap = cv2.VideoCapture(inPath)
        try:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            success,frame = vidcap.read()
        finally:
            # a capture is opened on every slider move; release it every time
            vidcap.release()
        if not success:
            # missing file or index past the last frame: keep the current image
            return
        img_tmp,_,__ = cv2.split(frame)
        img = np.flipud(img_tmp)
        sourceROI.data['image'] = [img]

    # the following 2 functions are used to animate the mp4

    def update_ROI_slider():
        """Advance the movie slider by one 5/60 step.

        When the next position would lie beyond the slider's end, the
        play/pause toggle ``animate_ROI_movie`` is invoked instead of moving
        the slider.

        Returns:
            The value of ``callback_id`` as visible from this function.
        """
        next_time = ROI_movie_slider.value + 5/60
        slider_end = ROI_movie_slider.end
        if next_time > slider_end:
            animate_ROI_movie()
            return callback_id
        ROI_movie_slider.value = next_time
        return callback_id

    def animate_ROI_movie():
        """Toggle movie playback between playing and paused.

        While the button shows '► Play', a periodic callback that steps the
        slider every 10 ms is registered and the button switches to
        '❚❚ Pause'; otherwise the registered callback is removed and the
        button switches back.

        Returns:
            The handle of the periodic callback stored in ``callback_id``.
        """
        # NOTE(review): `global` binds the module-level name, not a variable
        # of an enclosing function - confirm this is the intended storage.
        global callback_id
        start_playing = ROI_movie_play_button.label == '► Play'
        if not start_playing:
            ROI_movie_play_button.label = '► Play'
            curdoc().remove_periodic_callback(callback_id)
            return callback_id
        ROI_movie_play_button.label = '❚❚ Pause'
        callback_id = curdoc().add_periodic_callback(update_ROI_slider, 10)
        return callback_id

### Application Content ###

    # main plot for segmentation and contour finding

    # grey-level colour mapping over the 0..255 value range
    cmap = LinearColorMapper(palette="Greys256", low=0, high=255)
    # tool list shared by all figures created below
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save,box_select,lasso_select"
    # NOTE(review): IMG_TOOLTIPS is not referenced in the visible part of the
    # file - confirm whether a hover tool is still meant to use it.
    IMG_TOOLTIPS = [('name', "@img_name"),("x", "$x"),("y", "$y"),("value", "@image")]

    # Single-row data source holding the application state of the main image
    # (images, geometry, contour count, ROI coordinates and the input/output
    # locations).  Every column starts with a placeholder entry.
    source = ColumnDataSource(data=dict(image=[0],bin_img=[0],image_orig=[0],
            x=[0], y=[0], dw=[0], dh=[0], num_contours=[0], roi_coords=[0],
            input_dir=[''],output_dir=[''],img_name=['']))
    # positions and texts of the ROI labels drawn on the main figure
    source_label = ColumnDataSource(data=dict(x=[0], y=[0], label=['']))
    # multi-point outlines (xs/ys columns); presumably the detected contours
    source_contours = ColumnDataSource(data=dict(xs=[0], ys=[0]))

    # white text annotations fed from source_label
    roi_labels = LabelSet(x='x', y='y', text='label',source=source_label, 
        level='annotation',text_color='white',text_font_size='12pt')

    # main figure; the ROI label annotation is attached to it as a layout
    p = figure(tools=TOOLS, toolbar_location=("right"))
    p.add_layout(roi_labels)
    # no padding around the data on either axis
    p.x_range.range_padding = p.y_range.range_padding = 0

    # turn off gridlines and hide both axes
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.axis.visible = False


    # ROI plots 

    # data source for the single ROI movie frame shown in p_ROI; note that the
    # img_name placeholder is the integer 0 here, not an empty string
    sourceROI = ColumnDataSource(data=dict(image=[0],
            x=[0], y=[0], dw=[0], dh=[0], img_name=[0]))
    # filled elsewhere with one ColumnDataSource per ROI
    sources_stats = {}

    # hover tooltips for the statistics curves
    TOOLTIPS = [('name','$name'),('time', '@time'),('stat', "$y")]
    # one statistics figure (and one tab) is created per entry of this array
    stats = np.array(['mean','var','min','max','median','skew','kurt','rawDensity','COMx','COMy'])
    p_stats = []
    tabs = []
    for i in range(len(stats)):
        p_stats.append(figure(tools=TOOLS, plot_height=300, plot_width=600))
        p_stats[i].x_range.range_padding = p_stats[i].y_range.range_padding = 0
        # the tab title is the statistic's name
        tabs.append(Panel(child=p_stats[i], title=stats[i]))

    # small figure in which the selected ROI movie frame is displayed
    p_ROI = figure(tools=TOOLS, toolbar_location=("right"), plot_height=300, plot_width=300)
    p_ROI.x_range.range_padding = p_ROI.y_range.range_padding = 0  

    # turn off gridlines and hide both axes
    p_ROI.xgrid.grid_line_color = p_ROI.ygrid.grid_line_color = None
    p_ROI.axis.visible = False


    # Widgets - Buttons, Sliders, Text, Etc.

    intro = Div(text="""<h2>Droplet Recognition and Analysis with Bokeh</h2> 
        This application is designed to help segment a grayscale image into 
        regions of interest (ROIs) and perform analysis on those regions.<br>
        <h4>How to Use This Application:</h4>
        <ol>
        <li>Load in a grayscale mp4 file and choose a save directory.</li>
        <li>Apply various filters for thresholding. Use <b>Close</b>, <b>Dilate</b> 
        and <b>Erode</b> buttons to adjust each binary image further.</li>
        <li>Use <b>Find Contours</b> button to search the image for closed shape. 
        The <b>Contour Size Range</b> slider will change size of the perimeter to
        be identified. You can apply new thresholds and repeat until satisfied with
        the region selection. Total regions detected is displayed next to
        the button.</li>
        <li>When satisfied, use <b>Export ROIs</b> to write ROI locations and 
        contour finding parameters to file.</li>
        <li><b>Create ROI Movies</b> to write mp4s of the selected regions.</li>
        <li>Use <b>Calculate ROI Stats</b> to perform calculations on the 
        newly created mp4 files.</li>
        <li>Finally, use <b>Load ROI Files</b> to load in the data that you just
        created and view the plots. The statistics plots can be overlaid by 
        selecting multiple labels. Individual ROI mp4s can be animated or you can
        use the slider to move through the frames.</li>
        </ol>
        Note: messages and progress bars are displayed below the GUI.""",
        style={'font-size':'10pt'},width=1000)

    # --- input file selection ---
    file_button = Button(label="Choose File",button_type="primary")
    file_button.on_click(update_filename)
    # read-only text showing the current input file name
    inFile = PreText(text='Input File:\n'+source.data["img_name"][0], background=(255,255,255,0.5), width=500)

    # --- thresholding: every change of a control below re-runs apply_filter();
    # the (attr, old, new) arguments of the change event are ignored ---
    filter_LABELS = ["Original","OTSU", "Isodata", "Mean", "Li","Yen","Local"]
    radio_button_gp = RadioButtonGroup(labels=filter_LABELS, active=0, width=600)
    radio_button_gp.on_change('active', lambda attr, old, new: apply_filter())
    
    offset_spinner = Spinner(low=0, high=500, value=1, step=1, width=100, title="Local: Offset",
                            background=(255,255,255,0.5))
    offset_spinner.on_change('value', lambda attr, old, new: apply_filter())
    # starts at 25 and moves in steps of 2 between 1 and 101
    # (presumably to keep the block size odd - TODO confirm)
    block_spinner = Spinner(low=1, high=101, value=25, step=2, width=100, title="Local: Block Size",
                           background=(255,255,255,0.5))
    block_spinner.on_change('value', lambda attr, old, new: apply_filter())
    
    # --- adjustment buttons for the binary image ---
    closing_button = Button(label="Close",button_type="default", width=100)
    closing_button.on_click(close_img)
    dilation_button = Button(label="Dilate",button_type="default", width=100)
    dilation_button.on_click(dilate_img)
    erosion_button = Button(label="Erode",button_type="default", width=100)
    erosion_button.on_click(erode_img)

    # --- contour detection ---
    contour_rng_slider = RangeSlider(start=10, end=500, value=(200,350), step=1, width=300, 
            title="Contour Size Range", background=(255,255,255,0.5), bar_color='gray')
    contour_button = Button(label="Find Contours", button_type="success")
    contour_button.on_click(find_contours)
    # read-only text showing the current value of num_contours
    contours_found = PreText(text='Droplets Detected: '+str(source.data["num_contours"][0]), background=(255,255,255,0.5))
    
    # --- export and output directory ---
    exportROIs_button = Button(label="Export ROIs", button_type="success", width=200)
    exportROIs_button.on_click(export_ROIs)    

    changeDir_button = Button(label="Change Directory",button_type="primary", width=150)
    changeDir_button.on_click(change_directory)
    outDir = PreText(text='Save Directory:\n'+source.data["output_dir"][0], background=(255,255,255,0.5), width=500)

    # --- ROI movie creation and statistics ---
    create_ROIs_button = Button(label="Create ROI Movies",button_type="success", width=200)
    create_ROIs_button.on_click(create_ROI_movies)

    process_ROIs_button = Button(label="Calculate ROI Stats",button_type="success")
    process_ROIs_button.on_click(process_ROIs)

    # An empty, toolbar-less figure whose only content is its title; it is
    # placed next to the vertical range slider in the layout, apparently as
    # that slider's caption.
    display_rng_text = figure(title="Display Range", title_location="left", 
                        width=40, height=300, toolbar_location=None, min_border=0, 
                        outline_line_color=None)
    display_rng_text.title.align="center"
    display_rng_text.title.text_font_size = '10pt'
    display_rng_text.x_range.range_padding = display_rng_text.y_range.range_padding = 0

    # vertical 0..255 range slider; every change calls update_image()
    display_range_slider = RangeSlider(start=0, end=255, value=(0,255), step=1, 
        orientation='vertical', direction='rtl', 
        bar_color='gray', width=40, height=300, tooltips=True)
    display_range_slider.on_change('value', lambda attr, old, new: update_image())

    # --- loading and browsing of ROI results ---
    load_ROIfiles_button = Button(label="Load ROI Files",button_type="primary")
    load_ROIfiles_button.on_click(load_ROI_files)

    # options are filled in by load_ROI_files; a selection change shows/hides
    # the matching curves
    ROI_multi_select = MultiSelect(value=[], width=100, height=300)
    ROI_multi_select.on_change('value', lambda attr, old, new: update_ROI_plots())

    # labels are filled in by load_ROI_files; the active entry selects the movie
    ROI_movie_radio_group = RadioGroup(labels=[],width=60)
    ROI_movie_radio_group.on_change('active', lambda attr, old, new: load_ROI_movie())
    # time slider in hours with a step of 5/60; its end is reset by load_ROI_movie
    ROI_movie_slider = Slider(start=0,end=100,value=0,step=5/60,title="Time [h]", width=280)
    ROI_movie_slider.on_change('value', lambda attr, old, new: update_ROI_movie())

    # NOTE(review): animate_ROI_movie declares `global callback_id`, so it
    # does not rebind this name if this code runs inside a function - confirm.
    callback_id = None

    ROI_movie_play_button = Button(label='► Play',width=50)
    ROI_movie_play_button.on_click(animate_ROI_movie)

# initialize some data without having to choose file
    # fname = os.path.join(os.getcwd(),'data','Droplets.mp4')
    # load_data(filename=fname)


### Layout & Initialize application ###

    # ROI movie viewer: radio selector beside the frame display, with the time
    # slider and the play button in the row below
    ROI_layout = layout([
        [ROI_movie_radio_group, p_ROI],
        [ROI_movie_slider,ROI_movie_play_button]
        ])

    # Complete page, one inner list per row, in the order of the numbered
    # steps of the intro text: file choice, thresholding, binary adjustment,
    # contour finding, export, ROI processing, main image, ROI results.
    app = layout(children=[
        [intro],
        [file_button,inFile],
        [radio_button_gp, offset_spinner, block_spinner],
        [closing_button, dilation_button, erosion_button],
        [contour_rng_slider, contour_button, contours_found],
        [exportROIs_button, outDir, changeDir_button],
        [create_ROIs_button, process_ROIs_button],
        [display_rng_text, display_range_slider, p],
        [load_ROIfiles_button],
        [ROI_layout, ROI_multi_select, Tabs(tabs=tabs)]
    ])
    
    # attach the page to the document
    doc.add_root(app)
Exemplo n.º 19
0
def create():
    det_data = {}
    fit_params = {}
    js_data = ColumnDataSource(
        data=dict(content=["", ""], fname=["", ""], ext=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        """List the .ccl/.dat files of the entered proposal in ``file_select``.

        The stripped proposal string is joined to each root in
        ``pyzebra.ZEBRA_PROPOSALS_PATHS``; the first resulting path that is a
        directory is used.  Both file buttons are enabled afterwards.

        Raises:
            ValueError: if none of the roots contains such a directory.
        """
        proposal = new.strip()
        candidates = (
            os.path.join(root, proposal)
            for root in pyzebra.ZEBRA_PROPOSALS_PATHS
        )
        proposal_path = next(
            (path for path in candidates if os.path.isdir(path)), None)
        if proposal_path is None:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        # each option pairs the full path with the bare file name
        file_select.options = [
            (os.path.join(proposal_path, entry), entry)
            for entry in os.listdir(proposal_path)
            if entry.endswith((".ccl", ".dat"))
        ]
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        """Rebuild the scan table and the merge selector from ``det_data``.

        One table row is written per scan: its "idx", an "h k l" string, a fit
        flag reset to 0 and an export flag taken from the scan's "active"
        entry (True when absent).  The row selection is cleared and then set
        to the first row, and the first scan becomes the selected merge
        source.  When ``det_data`` holds no scans, the table and the selector
        options are emptied and nothing is selected, instead of failing with
        ``IndexError``.
        """
        scan_list = [s["idx"] for s in det_data]
        hkl = [f'{s["h"]} {s["k"]} {s["l"]}' for s in det_data]
        export = [s.get("active", True) for s in det_data]
        scan_table_source.data.update(
            scan=scan_list,
            hkl=hkl,
            fit=[0] * len(scan_list),
            export=export,
        )
        # cleared first; presumably so that selecting row 0 again is seen as
        # a change and the selection callback runs for the new data
        scan_table_source.selected.indices = []
        if scan_list:
            scan_table_source.selected.indices = [0]

        merge_options = [(str(i), f"{i} ({idx})")
                         for i, idx in enumerate(scan_list)]
        merge_from_select.options = merge_options
        if merge_options:
            merge_from_select.value = merge_options[0][0]

    file_select = MultiSelect(title="Available .ccl/.dat files:",
                              width=210,
                              height=250)

    def file_open_button_callback():
        nonlocal det_data
        det_data = []
        for f_name in file_select.value:
            with open(f_name) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    pyzebra.merge_datasets(det_data, append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    pyzebra.merge_duplicates(det_data)
                    js_data.data.update(fname=[base, base])

        _init_datatable()
        append_upload_button.disabled = False

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        """Merge the files selected in ``file_select`` into ``det_data``.

        Each file is parsed, normalized to the current monitor value and
        merged into the existing dataset; the scan table is rebuilt afterwards.
        """
        for f_name in file_select.value:
            ext = os.path.splitext(f_name)[1]
            with open(f_name) as file:
                extra = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(extra, monitor_spinner.value)
            pyzebra.merge_datasets(det_data, extra)

        _init_datatable()

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal det_data
        det_data = []
        for f_str, f_name in zip(new, upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    pyzebra.merge_datasets(det_data, append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    pyzebra.merge_duplicates(det_data)
                    js_data.data.update(fname=[base, base])

        _init_datatable()
        append_upload_button.disabled = False

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        """Merge the files uploaded via ``append_upload_button`` into ``det_data``.

        ``new`` holds the base64-encoded file contents, paired by position
        with ``append_upload_button.filename``.  Each file is parsed,
        normalized to the current monitor value and merged into the existing
        dataset; the scan table is rebuilt afterwards.
        """
        for f_str, f_name in zip(new, append_upload_button.filename):
            text = base64.b64decode(f_str).decode()
            with io.StringIO(text) as file:
                ext = os.path.splitext(f_name)[1]
                extra = pyzebra.parse_1D(file, ext)

            pyzebra.normalize_dataset(extra, monitor_spinner.value)
            pyzebra.merge_datasets(det_data, extra)

        _init_datatable()

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    append_upload_button = FileInput(accept=".ccl,.dat",
                                     multiple=True,
                                     width=200,
                                     disabled=True)
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, old, new):
        """Re-normalize the loaded data to the new monitor value and replot.

        Does nothing while no data is loaded.
        """
        if not det_data:
            return
        pyzebra.normalize_dataset(det_data, new)
        _update_plot(_get_selected_scan())

    monitor_spinner = Spinner(title="Monitor:",
                              mode="int",
                              value=100_000,
                              low=1,
                              width=145)
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def _update_table():
        """Refresh the table's "fit" column: 1 for scans holding a "fit" entry, else 0."""
        fit_flags = [int("fit" in scan) for scan in det_data]
        scan_table_source.data.update(fit=fit_flags)

    def _update_plot(scan):
        """Redraw the main plot for ``scan``.

        The scatter shows the scan's "counts" against the values of its scan
        motor, with whiskers at counts +/- sqrt(counts).  If the scan holds a
        "fit" entry, the fitted curve, the linear background and the peak
        components are drawn on a 100-point grid and the fit report is shown;
        otherwise all fit-related sources and the report are cleared.
        """
        motor = scan["scan_motor"]
        counts = scan["counts"]
        positions = scan[motor]
        spread = np.sqrt(counts)

        plot.axis[0].axis_label = motor
        plot_scatter_source.data.update(x=positions,
                                        y=counts,
                                        y_upper=counts + spread,
                                        y_lower=counts - spread)

        fit = scan.get("fit")
        if fit is None:
            # nothing fitted for this scan: blank every fit-related element
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""
            return

        x_fit = np.linspace(positions[0], positions[-1], 100)
        plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

        bkg_x, bkg_y = [], []
        peak_xs, peak_ys = [], []
        components = fit.eval_components(x=x_fit)
        peak_kinds = ("gaussian", "voigt", "pvoigt")
        # component i of the fit is looked up under the key "f<i>_"
        for i, model in enumerate(fit_params):
            if "linear" in model:
                bkg_x = x_fit
                bkg_y = components[f"f{i}_"]
            elif any(kind in model for kind in peak_kinds):
                peak_xs.append(x_fit)
                peak_ys.append(components[f"f{i}_"])

        plot_bkg_source.data.update(x=bkg_x, y=bkg_y)
        plot_peak_source.data.update(xs=peak_xs, ys=peak_ys)

        fit_output_textinput.value = fit.fit_report()

    # Main plot
    plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(only_visible=True),
        plot_height=470,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue"))
    plot.add_layout(
        Whisker(source=plot_scatter_source,
                base="x",
                upper="y_upper",
                lower="y_lower"))

    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"))

    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"))

    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)

    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_to_span)

    plot.add_layout(
        Legend(
            items=[
                ("data", [plot_scatter]),
                ("best fit", [plot_fit]),
                ("peak", [plot_peak]),
                ("linear", [plot_bkg]),
            ],
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        """Plot the newly selected scan, enforcing single-row selection.

        A multi-row selection (Shift+Click or Ctrl+Click) is reverted to the
        previous one.  Empty selections, and the follow-up event caused by
        such a revert, are ignored.
        """
        if len(new) > 1:
            # drop selection to the previous one
            scan_table_source.selected.indices = old
        elif new and len(old) <= 1:
            # a plain single selection that is not the echo of a revert
            _update_plot(det_data[new[0]])

    def scan_table_source_callback(_attr, _old, _new):
        """Regenerate the export preview; the change-event arguments are ignored."""
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(scan=[], hkl=[], fit=[], export=[]))
    scan_table_source.on_change("data", scan_table_source_callback)

    scan_table = DataTable(
        source=scan_table_source,
        columns=[
            TableColumn(field="scan", title="Scan", width=50),
            TableColumn(field="hkl", title="hkl", width=100),
            TableColumn(field="fit", title="Fit", width=50),
            TableColumn(field="export",
                        title="Export",
                        editor=CheckboxEditor(),
                        width=50),
        ],
        width=310,  # +60 because of the index column
        height=350,
        autosize_mode="none",
        editable=True,
    )

    scan_table_source.selected.on_change("indices", scan_table_select_callback)

    def _get_selected_scan():
        """Return the ``det_data`` entry of the first selected table row.

        Raises ``IndexError`` when no row is selected.
        """
        return det_data[scan_table_source.selected.indices[0]]

    merge_from_select = Select(title="scan:", width=145)

    def merge_button_callback():
        """Merge the scan chosen in ``merge_from_select`` into the selected scan.

        Prints a warning and does nothing when both refer to the same scan.
        The plot is refreshed after a successful merge.
        """
        target = _get_selected_scan()
        donor = det_data[int(merge_from_select.value)]

        if target is donor:
            print("WARNING: Selected scans for merging are identical")
            return

        pyzebra.merge_scans(target, donor)
        _update_plot(_get_selected_scan())

    merge_button = Button(label="Merge into current", width=145)
    merge_button.on_click(merge_button_callback)

    def restore_button_callback():
        """Call ``pyzebra.restore_scan`` on the selected scan and replot it."""
        pyzebra.restore_scan(_get_selected_scan())
        # the selection is looked up again rather than reusing the first result
        _update_plot(_get_selected_scan())

    restore_button = Button(label="Restore scan", width=145)
    restore_button.on_click(restore_button_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        """Set the location of ``fit_from_span`` to the new spinner value."""
        fit_from_span.location = new

    # spinner whose value is passed as `fit_from` to pyzebra.fit_scan
    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    def fit_to_spinner_callback(_attr, _old, new):
        """Set the location of ``fit_to_span`` to the new spinner value."""
        fit_to_span.location = new

    # spinner whose value is passed as `fit_to` to pyzebra.fit_scan
    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        """Add a fit function of kind ``click.item`` to the selection list.

        The entry gets the tag "<kind>-<counter>", where the running counter
        is kept in ``fitparams_select.tags[0]`` and incremented afterwards;
        default parameters for the kind are stored under that tag.
        """
        function = click.item
        counter = fitparams_select.tags[0]
        # bokeh requires (str, str) for MultiSelect options
        new_tag = f"{function}-{counter}"
        fitparams_select.options.append((new_tag, function))
        fit_params[new_tag] = fitparams_factory(function)
        fitparams_select.tags[0] += 1

    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=[
            ("Linear", "linear"),
            ("Gaussian", "gaussian"),
            ("Voigt", "voigt"),
            ("Pseudo Voigt", "pvoigt"),
            # ("Pseudo Voigt1", "pseudovoigt1"),
        ],
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        """Show the parameters of the selected fit function in the table.

        A multi-entry selection (Shift+Click or Ctrl+Click) is reverted to
        the previous one and the follow-up event of such a revert is ignored.
        An empty selection empties the parameter table.
        """
        if len(new) > 1:
            # drop selection to the previous one
            fitparams_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        if not new:
            fitparams_table_source.data.update(
                dict(param=[], value=[], vary=[], min=[], max=[]))
            return

        fitparams_table_source.data.update(fit_params[new[0]])

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        """Remove the selected fit function and clear the selection.

        Its parameters are deleted from ``fit_params`` and the first option
        carrying its tag is removed from ``fitparams_select``.  Does nothing
        when no function is selected.
        """
        if not fitparams_select.value:
            return

        sel_tag = fitparams_select.value[0]
        del fit_params[sel_tag]
        matching = next(
            (opt for opt in fitparams_select.options if opt[0] == sel_tag),
            None)
        if matching is not None:
            fitparams_select.options.remove(matching)

        fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        """Build the default parameter table for a fit function.

        Args:
            function: one of "linear", "gaussian", "voigt", "pvoigt" or
                "pseudovoigt1".

        Returns:
            A dict with the keys "param", "value", "vary", "min" and "max",
            each a list with one entry per parameter.  Values and bounds
            default to None and every parameter varies, except for the
            per-function presets applied below.

        Raises:
            ValueError: if ``function`` is not a known fit function.
        """
        # built on every call so that each result owns fresh lists
        param_names = {
            "linear": ["slope", "intercept"],
            "gaussian": ["amplitude", "center", "sigma"],
            "voigt": ["amplitude", "center", "sigma", "gamma"],
            "pvoigt": ["amplitude", "center", "sigma", "fraction"],
            "pseudovoigt1": ["amplitude", "center", "g_sigma", "l_sigma",
                             "fraction"],
        }
        if function not in param_names:
            raise ValueError("Unknown fit function")

        names = param_names[function]
        count = len(names)
        fitparams = {
            "param": names,
            "value": [None] * count,
            "vary": [True] * count,
            "min": [None] * count,
            "max": [None] * count,
        }

        # presets that replace the generic defaults for some functions
        presets = {
            "linear": {
                "value": [0, 1],
                "vary": [False, True],
                "min": [None, 0],
            },
            "gaussian": {
                "min": [0, None, None],
            },
        }
        fitparams.update(presets.get(function, {}))

        return fitparams

    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=[
            TableColumn(field="param", title="Parameter"),
            TableColumn(field="value", title="Value", editor=NumberEditor()),
            TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
            TableColumn(field="min", title="Min", editor=NumberEditor()),
            TableColumn(field="max", title="Max", editor=NumberEditor()),
        ],
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `background` and `gauss` fit functions added
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="linear"))
    fitparams_add_dropdown_callback(types.SimpleNamespace(item="gaussian"))
    fitparams_select.value = ["gaussian-1"]  # add selection to gauss

    fit_output_textinput = TextAreaInput(title="Fit results:",
                                         width=750,
                                         height=200)

    def proc_all_button_callback():
        """Fit and integrate every scan whose "export" flag is set.

        Each flagged scan is fitted within the current fit range and its area
        computed with the chosen method and Lorentz setting.  The plot of the
        selected scan and the table's fit column are refreshed afterwards.
        """
        export_flags = scan_table_source.data["export"]
        for scan, export in zip(det_data, export_flags):
            if not export:
                continue
            pyzebra.fit_scan(scan,
                             fit_params,
                             fit_from=fit_from_spinner.value,
                             fit_to=fit_to_spinner.value)
            pyzebra.get_area(
                scan,
                area_method=AREA_METHODS[area_method_radiobutton.active],
                lorentz=lorentz_checkbox.active,
            )

        _update_plot(_get_selected_scan())
        _update_table()

    proc_all_button = Button(label="Process All",
                             button_type="primary",
                             width=145)
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        """Fit and integrate the currently selected scan only.

        The scan is fitted within the current fit range and its area computed
        with the chosen method and Lorentz setting; the plot and the table's
        fit column are refreshed afterwards.
        """
        scan = _get_selected_scan()
        fit_window = {
            "fit_from": fit_from_spinner.value,
            "fit_to": fit_to_spinner.value,
        }
        pyzebra.fit_scan(scan, fit_params, **fit_window)

        method = AREA_METHODS[area_method_radiobutton.active]
        pyzebra.get_area(scan,
                         area_method=method,
                         lorentz=lorentz_checkbox.active)

        _update_plot(scan)
        _update_table()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(labels=["Function", "Area"],
                                         active=0,
                                         width=145)

    lorentz_checkbox = CheckboxGroup(labels=["Lorentz Correction"],
                                     width=145,
                                     margin=(13, 5, 5, 5))

    export_preview_textinput = TextAreaInput(title="Export file preview:",
                                             width=500,
                                             height=400)

    def _update_preview():
        """Regenerate the export preview and the download payload.

        The scans flagged for export are written with ``pyzebra.export_1D``
        into a temporary directory.  For every extension of the chosen export
        target, the produced file (if any) is read back: its content goes into
        ``js_data`` (an empty string when the file is absent) and, prefixed by
        a "<ext> file:" header, into the preview text area.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp"
            flags = scan_table_source.data["export"]
            export_data = [s for s, export in zip(det_data, flags) if export]

            pyzebra.export_1D(
                export_data,
                temp_file,
                export_target_select.value,
                hkl_precision=int(hkl_precision_select.value),
            )

            preview_parts = []
            file_content = []
            for ext in EXPORT_TARGETS[export_target_select.value]:
                fname = temp_file + ext
                content = ""
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                    preview_parts.append(f"{ext} file:\n" + content)
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = "".join(preview_parts)

    def export_target_select_callback(_attr, _old, new):
        """Store the new target's file extensions in ``js_data`` and refresh the preview."""
        js_data.data.update(ext=EXPORT_TARGETS[new])
        _update_preview()

    # the selectable targets are the keys of EXPORT_TARGETS
    export_target_select = Select(title="Export target:",
                                  options=list(EXPORT_TARGETS.keys()),
                                  value="fullprof",
                                  width=80)
    export_target_select.on_change("value", export_target_select_callback)
    # seed js_data with the extensions of the initial target, since the
    # callback above only runs on later changes
    js_data.data.update(ext=EXPORT_TARGETS[export_target_select.value])

    def hkl_precision_select_callback(_attr, _old, _new):
        """Refresh the export preview; the new value is read from the widget there."""
        _update_preview()

    # option values are strings; _update_preview converts the choice with int()
    hkl_precision_select = Select(title="hkl precision:",
                                  options=["2", "3", "4"],
                                  value="2",
                                  width=80)
    hkl_precision_select.on_change("value", hkl_precision_select_callback)

    save_button = Button(label="Download File(s)",
                         button_type="success",
                         width=200)
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))

    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(
        scan_table,
        row(monitor_spinner, column(Spacer(height=19), restore_button)),
        row(column(Spacer(height=19), merge_button), merge_from_select),
    )

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(
        export_preview_textinput,
        row(export_target_select, hkl_precision_select,
            column(Spacer(height=19), row(save_button))),
    )

    tab_layout = column(
        row(import_layout, scan_layout, plot, Spacer(width=30), export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="ccl integrate")
Exemplo n.º 20
0
    def __init__(self, image_views, disp_min=0, disp_max=1000, colormap="plasma"):
        """Set up the color mappers, the color bar and the colormap widgets.

        A linear and a logarithmic color mapper are created with the same
        palette and display range; the linear one is attached to every image
        view and to the color bar. The widgets built here switch between the
        two mappers and keep their palette, range and special colors in sync.

        Args:
            image_views (ImageView): Associated streamvis image view instances.
            disp_min (int, optional): Initial minimal display value. Defaults to 0.
            disp_max (int, optional): Initial maximal display value. Defaults to 1000.
            colormap (str, optional): Initial colormap. Defaults to 'plasma'.
        """
        initial_palette = cmap_dict[colormap]
        lin_colormapper = LinearColorMapper(palette=initial_palette, low=disp_min, high=disp_max)
        log_colormapper = LogColorMapper(palette=initial_palette, low=disp_min, high=disp_max)
        # Linear first, logarithmic second: callbacks update them in this order
        both_mappers = (lin_colormapper, log_colormapper)

        for view in image_views:
            view.image_glyph.color_mapper = lin_colormapper

        color_bar = ColorBar(
            color_mapper=lin_colormapper,
            location=(0, -5),
            orientation="horizontal",
            height=15,
            width=100,
            padding=5,
        )
        self.color_bar = color_bar

        # ---- selector
        def select_callback(_attr, _old, new):
            # Unknown colormap names are silently ignored
            if new not in cmap_dict:
                return

            chosen_palette = cmap_dict[new]
            for mapper in both_mappers:
                mapper.palette = chosen_palette
            # The high-color picker follows the last color of the palette
            high_color.color = chosen_palette[-1]

        select = Select(
            title="Colormap:", value=colormap, options=list(cmap_dict.keys()), default_size=100
        )
        select.on_change("value", select_callback)
        self.select = select

        # ---- auto toggle button
        def auto_toggle_callback(state):
            # Manual range spinners are locked while the auto range is active
            locked = bool(state)
            display_min_spinner.disabled = locked
            display_max_spinner.disabled = locked

        auto_toggle = CheckboxGroup(labels=["Auto Colormap Range"], default_size=145)
        auto_toggle.on_click(auto_toggle_callback)
        self.auto_toggle = auto_toggle

        # ---- scale radiobutton group
        def _activate(mapper, ticker_type):
            # Point every image view and the color bar at the given mapper
            for view in image_views:
                view.image_glyph.color_mapper = mapper
            color_bar.color_mapper = mapper
            color_bar.ticker = ticker_type()

        def scale_radiobuttongroup_callback(selection):
            if selection == 0:  # Linear
                _activate(lin_colormapper, BasicTicker)
            elif self.disp_min > 0:  # Logarithmic
                _activate(log_colormapper, LogTicker)
            else:
                # Logarithmic scale is refused for a non-positive minimum:
                # fall back to the "Linear" button
                scale_radiobuttongroup.active = 0

        scale_radiobuttongroup = RadioGroup(
            labels=["Linear", "Logarithmic"], active=0, default_size=145
        )
        scale_radiobuttongroup.on_click(scale_radiobuttongroup_callback)
        self.scale_radiobuttongroup = scale_radiobuttongroup

        # ---- display max value
        def display_max_spinner_callback(_attr, _old_value, new_value):
            # Keep the min spinner strictly below the new maximum
            self.display_min_spinner.high = new_value - STEP
            if new_value <= 0:
                scale_radiobuttongroup.active = 0

            for mapper in both_mappers:
                mapper.high = new_value

        display_max_spinner = Spinner(
            title="Max Display Value:",
            low=disp_min + STEP,
            value=disp_max,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        display_max_spinner.on_change("value", display_max_spinner_callback)
        self.display_max_spinner = display_max_spinner

        # ---- display min value
        def display_min_spinner_callback(_attr, _old_value, new_value):
            # Keep the max spinner strictly above the new minimum
            self.display_max_spinner.low = new_value + STEP
            if new_value <= 0:
                scale_radiobuttongroup.active = 0

            for mapper in both_mappers:
                mapper.low = new_value

        display_min_spinner = Spinner(
            title="Min Display Value:",
            high=disp_max - STEP,
            value=disp_min,
            step=STEP,
            disabled=bool(auto_toggle.active),
            default_size=145,
        )
        display_min_spinner.on_change("value", display_min_spinner_callback)
        self.display_min_spinner = display_min_spinner

        # ---- high color
        def high_color_callback(_attr, _old_value, new_value):
            for mapper in both_mappers:
                mapper.high_color = new_value

        high_color = ColorPicker(title="High Color:", color=initial_palette[-1], default_size=90)
        high_color.on_change("color", high_color_callback)
        self.high_color = high_color

        # ---- mask color
        def mask_color_callback(_attr, _old_value, new_value):
            for mapper in both_mappers:
                mapper.nan_color = new_value

        mask_color = ColorPicker(title="Mask Color:", color="gray", default_size=90)
        mask_color.on_change("color", mask_color_callback)
        self.mask_color = mask_color
Exemplo n.º 21
0
def create(palm):
    """Build the "THz Calibration" tab.

    Creates the calibration plot together with the path, scan, fit-range,
    save and load controls, wires their callbacks to ``palm`` and registers
    periodic menu refreshes on the current document.

    Args:
        palm: Object providing ``calibrate_thz``, ``save_thz_calib``,
            ``load_thz_calib`` and the ``thz_*`` attributes read below.

    Returns:
        Panel: The assembled tab, titled "THz Calibration".
    """
    # Current fit range; rebound by the spinner callbacks via ``nonlocal``
    fit_max = 1
    fit_min = 0

    doc = curdoc()

    # THz calibration plot
    scan_plot = Plot(
        title=Title(text="THz calibration"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )

    # ---- tools
    scan_plot.toolbar.logo = None
    scan_plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

    # ---- axes
    scan_plot.add_layout(LinearAxis(axis_label="Stage delay motor"),
                         place="below")
    scan_plot.add_layout(LinearAxis(axis_label="Energy shift, eV",
                                    major_label_orientation="vertical"),
                         place="left")

    # ---- grid lines
    scan_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    scan_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- circle cluster glyphs
    scan_circle_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(scan_circle_source,
                        Circle(x="x", y="y", line_alpha=0, fill_alpha=0.5))

    # ---- circle glyphs
    scan_avg_circle_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(
        scan_avg_circle_source,
        Circle(x="x", y="y", line_color="purple", fill_color="purple"))

    # ---- line glyphs
    fit_line_source = ColumnDataSource(dict(x=[], y=[]))
    scan_plot.add_glyph(fit_line_source, Line(x="x",
                                              y="y",
                                              line_color="purple"))

    # THz calibration folder path text input
    def path_textinput_callback(_attr, _old_value, _new_value):
        # Refresh both menus right away instead of waiting for the periodic update
        update_load_dropdown_menu()
        path_periodic_update()

    path_textinput = TextInput(title="THz calibration path:",
                               value=os.path.join(os.path.expanduser("~")),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # THz calibration eco scans dropdown
    def scans_dropdown_callback(_attr, _old_value, new_value):
        # Show the selected scan file name on the button itself
        scans_dropdown.label = new_value

    scans_dropdown = Dropdown(label="ECO scans",
                              button_type="default",
                              menu=[])
    scans_dropdown.on_change("value", scans_dropdown_callback)

    # ---- eco scans periodic update
    def path_periodic_update():
        """List the ``.json`` files of the current path in the scans dropdown."""
        new_menu = []
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(".json"):
                    new_menu.append((entry.name, entry.name))
        scans_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Calibrate button
    def calibrate_button_callback():
        palm.calibrate_thz(
            path=os.path.join(path_textinput.value, scans_dropdown.value))
        # Widen the fit range to the rounded extent of the calibration index
        fit_max_spinner.value = np.ceil(palm.thz_calib_data.index.values.max())
        fit_min_spinner.value = np.floor(
            palm.thz_calib_data.index.values.min())
        update_calibration_plot()

    def update_calibration_plot():
        """Redraw the data points, the averages, the fit line and the slope text."""
        scan_plot.xaxis.axis_label = f"{palm.thz_motor_name}, {palm.thz_motor_unit}"

        # One point per individual peak shift: each index value is repeated
        # as many times as there are shifts stored for it
        scan_circle_source.data.update(
            x=np.repeat(palm.thz_calib_data.index,
                        palm.thz_calib_data["peak_shift"].apply(len)).tolist(),
            y=np.concatenate(
                palm.thz_calib_data["peak_shift"].values).tolist(),
        )

        scan_avg_circle_source.data.update(
            x=palm.thz_calib_data.index.tolist(),
            y=palm.thz_calib_data["peak_shift_mean"].tolist())

        # Straight fit line evaluated over the currently selected fit range
        x = np.linspace(fit_min, fit_max, 100)
        y = palm.thz_slope * x + palm.thz_intersect
        fit_line_source.data.update(x=np.round(x, decimals=5),
                                    y=np.round(y, decimals=5))

        calib_const_div.text = f"""
        thz_slope = {palm.thz_slope}
        """

    calibrate_button = Button(label="Calibrate THz",
                              button_type="default",
                              width=250)
    calibrate_button.on_click(calibrate_button_callback)

    # THz fit maximal value spinner
    def fit_max_spinner_callback(_attr, old_value, new_value):
        nonlocal fit_max
        if new_value > fit_min:
            fit_max = new_value
            palm.calibrate_thz(
                path=os.path.join(path_textinput.value, scans_dropdown.value),
                fit_range=(fit_min, fit_max),
            )
            update_calibration_plot()
        else:
            # Reject a maximum that is not above the current minimum
            fit_max_spinner.value = old_value

    fit_max_spinner = Spinner(title="Maximal fit value:",
                              value=fit_max,
                              step=0.1)
    fit_max_spinner.on_change("value", fit_max_spinner_callback)

    # THz fit minimal value spinner
    def fit_min_spinner_callback(_attr, old_value, new_value):
        nonlocal fit_min
        if new_value < fit_max:
            fit_min = new_value
            palm.calibrate_thz(
                path=os.path.join(path_textinput.value, scans_dropdown.value),
                fit_range=(fit_min, fit_max),
            )
            update_calibration_plot()
        else:
            # Reject a minimum that is not below the current maximum
            fit_min_spinner.value = old_value

    fit_min_spinner = Spinner(title="Minimal fit value:",
                              value=fit_min,
                              step=0.1)
    fit_min_spinner.on_change("value", fit_min_spinner_callback)

    # Save calibration button
    def save_button_callback():
        palm.save_thz_calib(path=path_textinput.value)
        # Make the freshly saved file show up in the load menu immediately
        update_load_dropdown_menu()

    save_button = Button(label="Save", button_type="default", width=250)
    save_button.on_click(save_button_callback)

    # Load calibration button
    def load_dropdown_callback(_attr, _old_value, new_value):
        # The empty value is written by the reset below; nothing to load then
        if new_value:
            palm.load_thz_calib(os.path.join(path_textinput.value, new_value))
            update_calibration_plot()

            # Drop selection, so that this callback can be triggered again on the same dropdown
            # menu item from the user perspective
            load_dropdown.value = ""

    def update_load_dropdown_menu():
        """List the ``.palm_thz`` files of the current path in the load dropdown.

        The button turns "danger" and the menu is emptied when the path is
        not an existing directory.
        """
        new_menu = []
        calib_file_ext = ".palm_thz"
        if os.path.isdir(path_textinput.value):
            for entry in os.scandir(path_textinput.value):
                if entry.is_file() and entry.name.endswith(calib_file_ext):
                    # Menu label is the file name without the extension
                    new_menu.append(
                        (entry.name[:-len(calib_file_ext)], entry.name))
            load_dropdown.button_type = "default"
            load_dropdown.menu = sorted(new_menu, reverse=True)
        else:
            load_dropdown.button_type = "danger"
            load_dropdown.menu = new_menu

    doc.add_next_tick_callback(update_load_dropdown_menu)
    doc.add_periodic_callback(update_load_dropdown_menu, 5000)

    load_dropdown = Dropdown(label="Load", menu=[], width=250)
    load_dropdown.on_change("value", load_dropdown_callback)

    # Calibration constants
    calib_const_div = Div(text=f"""
        thz_slope = {0}
        """)

    # assemble
    tab_layout = column(
        row(
            scan_plot,
            Spacer(width=30),
            column(
                path_textinput,
                scans_dropdown,
                calibrate_button,
                fit_max_spinner,
                fit_min_spinner,
                row(save_button, load_dropdown),
                calib_const_div,
            ),
        ))

    return Panel(child=tab_layout, title="THz Calibration")
Exemplo n.º 22
0
# Columns and widget of the summary statistics table
datatable_columns = [
    TableColumn(field=column_field, title=column_title)
    for column_field, column_title in (
        ("parameter", "Parameter"),
        ("value_all", "All Data"),
        ("value_selection", "Selected Data"),
    )
]

data_table = DataTable(
    source=datatable_source,
    columns=datatable_columns,
    width=450,
    height=125,
    index_position=None,
)

# Hook each input widget up to the handler that reacts to its changes
station_name_input.on_change('value', update_station)
simulation_number_input.on_change('value', update_n_simulations)
msmt_error_input.on_change('value', update_msmt_error)
sample_size_input.on_change('value', update_simulation_sample_size)
toggle_button.on_click(update_simulated_msmt_error)

# For threading information see the Bokeh server documentation:
# https://docs.bokeh.org/en/latest/docs/user_guide/server.html

# Run one update before the plots below are created
update()

# Time series plot built from the peak sources
ts_plot = create_ts_plot(peak_source, peak_flagged_source)

# React to changes of the selected indices of the peak source
peak_source.selected.on_change('indices', update_UI)

vh1, pv, hist_source = create_vhist(peak_source, ts_plot)
Exemplo n.º 23
0
def create(palm):
    """Build the "eTOF Calibration" tab.

    Creates the waveform and fit plots, the two editable peak tables, the
    path/scan/load controls and the noise threshold spinners, wires their
    callbacks to ``palm`` and registers periodic menu refreshes on the
    current document.

    Args:
        palm: Object providing the ``etofs`` mapping (keys "0" and "1" are
            used as reference and streaked) and the eTOF calibrate/save/load
            methods called below.

    Returns:
        Panel: The assembled tab, titled "eTOF Calibration".
    """
    doc = curdoc()

    def _navigation_tools():
        # Fresh tool instances for each plot
        return [PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool()]

    def _add_axes_and_grid(plot, x_label, y_label):
        # Bottom and left linear axes, then grid lines in both dimensions
        plot.add_layout(LinearAxis(axis_label=x_label), place="below")
        plot.add_layout(
            LinearAxis(axis_label=y_label, major_label_orientation="vertical"),
            place="left",
        )
        for dimension in (0, 1):
            plot.add_layout(Grid(dimension=dimension, ticker=BasicTicker()))

    # Calibration averaged waveforms per photon energy
    waveform_plot = Plot(
        title=Title(text="eTOF calibration waveforms"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=760,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )
    waveform_plot.toolbar.logo = None

    # ---- tools, axes, grid lines
    waveform_plot_hovertool = HoverTool(
        tooltips=[("energy, eV", "@en"), ("eTOF bin", "$x{0.}")])
    waveform_plot.add_tools(*_navigation_tools(), waveform_plot_hovertool)
    _add_axes_and_grid(waveform_plot, "eTOF time bin", "Intensity")

    # ---- multiline glyphs
    waveform_ref_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_ref_multiline = waveform_plot.add_glyph(
        waveform_ref_source, MultiLine(xs="xs", ys="ys", line_color="blue"))

    waveform_str_source = ColumnDataSource(dict(xs=[], ys=[], en=[]))
    waveform_str_multiline = waveform_plot.add_glyph(
        waveform_str_source, MultiLine(xs="xs", ys="ys", line_color="red"))

    # ---- legend
    waveform_legend = Legend(items=[
        ("reference", [waveform_ref_multiline]),
        ("streaked", [waveform_str_multiline]),
    ])
    waveform_plot.add_layout(waveform_legend)
    waveform_plot.legend.click_policy = "hide"

    # ---- vertical spans
    def _dashed_vertical_span(color):
        return Span(location=0,
                    dimension="height",
                    line_dash="dashed",
                    line_color=color)

    photon_peak_ref_span = _dashed_vertical_span("blue")
    photon_peak_str_span = _dashed_vertical_span("red")
    waveform_plot.add_layout(photon_peak_ref_span)
    waveform_plot.add_layout(photon_peak_str_span)

    # Calibration fit plot
    fit_plot = Plot(
        title=Title(text="eTOF calibration fit"),
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=PLOT_CANVAS_HEIGHT,
        plot_width=PLOT_CANVAS_WIDTH,
        toolbar_location="right",
    )
    fit_plot.toolbar.logo = None

    # ---- tools, axes, grid lines
    fit_plot.add_tools(*_navigation_tools())
    _add_axes_and_grid(fit_plot, "Photoelectron peak shift",
                       "Photon energy, eV")

    # ---- circle glyphs
    fit_ref_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_circle = fit_plot.add_glyph(
        fit_ref_circle_source, Circle(x="x", y="y", line_color="blue"))
    fit_str_circle_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_circle = fit_plot.add_glyph(
        fit_str_circle_source, Circle(x="x", y="y", line_color="red"))

    # ---- line glyphs
    fit_ref_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_ref_line = fit_plot.add_glyph(
        fit_ref_line_source, Line(x="x", y="y", line_color="blue"))
    fit_str_line_source = ColumnDataSource(dict(x=[], y=[]))
    fit_str_line = fit_plot.add_glyph(
        fit_str_line_source, Line(x="x", y="y", line_color="red"))

    # ---- legend
    fit_legend = Legend(items=[
        ("reference", [fit_ref_circle, fit_ref_line]),
        ("streaked", [fit_str_circle, fit_str_line]),
    ])
    fit_plot.add_layout(fit_legend)
    fit_plot.legend.click_policy = "hide"

    # Calibration results datatables
    def _apply_table_edits(etof_key, peak_field, table_data):
        # Copy the table rows into the calibration data of the given eTOF,
        # then refit all eTOFs and redraw
        etof = palm.etofs[etof_key]
        rows = zip(table_data["energy"], table_data[peak_field],
                   table_data["use_in_fit"])
        for energy, peak, use in rows:
            # The string "NaN" in a table cell stands for a missing peak
            etof.calib_data.loc[
                energy, "calib_tpeak"] = peak if peak != "NaN" else np.nan
            etof.calib_data.loc[energy, "use_in_fit"] = use

        refit = {
            key: palm.etofs[key].fit_calibration_curve()
            for key in palm.etofs
        }
        update_calibration_plot(refit)

    def _peak_table(source, peak_field, peak_title, height):
        # Editable three-column table: energy, peak position, use-in-fit flag
        table_columns = [
            TableColumn(field="energy",
                        title="Photon Energy, eV",
                        editor=IntEditor()),
            TableColumn(field=peak_field,
                        title=peak_title,
                        editor=IntEditor()),
            TableColumn(field="use_in_fit",
                        title=" ",
                        editor=CheckboxEditor(),
                        width=80),
        ]
        return DataTable(
            source=source,
            columns=table_columns,
            index_position=None,
            editable=True,
            height=height,
            width=250,
        )

    def datatable_ref_source_callback(_attr, _old_value, new_value):
        _apply_table_edits("0", "peak_pos_ref", new_value)

    datatable_ref_source = ColumnDataSource(
        dict(energy=[""] * 3,
             peak_pos_ref=[""] * 3,
             use_in_fit=[True] * 3))
    datatable_ref_source.on_change("data", datatable_ref_source_callback)

    datatable_ref = _peak_table(datatable_ref_source, "peak_pos_ref",
                                "Reference Peak", 300)

    def datatable_str_source_callback(_attr, _old_value, new_value):
        _apply_table_edits("1", "peak_pos_str", new_value)

    datatable_str_source = ColumnDataSource(
        dict(energy=[""] * 3,
             peak_pos_str=[""] * 3,
             use_in_fit=[True] * 3))
    datatable_str_source.on_change("data", datatable_str_source_callback)

    datatable_str = _peak_table(datatable_str_source, "peak_pos_str",
                                "Streaked Peak", 350)

    def _refresh_datatables():
        # Push the current calibration data into both tables, reference first
        datatable_ref_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_ref=palm.etofs["0"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["0"].calib_data["use_in_fit"].tolist(),
        )

        # NOTE(review): the streaked table takes its energy column from etof
        # "0" as well - presumably both eTOFs share the same index; confirm.
        datatable_str_source.data.update(
            energy=palm.etofs["0"].calib_data.index.tolist(),
            peak_pos_str=palm.etofs["1"].calib_data["calib_tpeak"].tolist(),
            use_in_fit=palm.etofs["1"].calib_data["use_in_fit"].tolist(),
        )

    # eTOF calibration folder path text input
    def path_textinput_callback(_attr, _old_value, _new_value):
        # Refresh both menus right away instead of waiting for the periodic update
        path_periodic_update()
        update_load_dropdown_menu()

    path_textinput = TextInput(title="eTOF calibration path:",
                               value=os.path.join(os.path.expanduser("~")),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # eTOF calibration eco scans dropdown
    def scans_dropdown_callback(_attr, _old_value, new_value):
        # Show the selected scan file name on the button itself
        scans_dropdown.label = new_value

    scans_dropdown = Dropdown(label="ECO scans",
                              button_type="default",
                              menu=[])
    scans_dropdown.on_change("value", scans_dropdown_callback)

    # ---- etof scans periodic update
    def path_periodic_update():
        scan_dir = path_textinput.value
        scan_files = []
        if os.path.isdir(scan_dir):
            scan_files = [(entry.name, entry.name)
                          for entry in os.scandir(scan_dir)
                          if entry.is_file() and entry.name.endswith(".json")]
        scans_dropdown.menu = sorted(scan_files, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Calibrate button
    def calibrate_button_callback():
        # Try the eco scan file first; on any failure fall back to
        # calibrating from the folder itself
        try:
            palm.calibrate_etof_eco(eco_scan_filename=os.path.join(
                path_textinput.value, scans_dropdown.value))
        except Exception:
            palm.calibrate_etof(folder_name=path_textinput.value)

        _refresh_datatables()

    def update_calibration_plot(calib_res):
        etof_ref = palm.etofs["0"]
        etof_str = palm.etofs["1"]

        # Stack the waveform pairs downwards: each pair is shifted below the
        # previous ones by the larger of its two maxima
        offset = 0
        ref_waveforms = []
        str_waveforms = []
        waveform_pairs = zip(etof_ref.calib_data["waveform"],
                             etof_str.calib_data["waveform"])
        for ref_waveform, str_waveform in waveform_pairs:
            offset -= max(ref_waveform.max(), str_waveform.max())
            ref_waveforms.append(ref_waveform + offset)
            str_waveforms.append(str_waveform + offset)

        waveform_ref_source.data.update(
            xs=len(etof_ref.calib_data) *
            [list(range(etof_ref.internal_time_bins))],
            ys=ref_waveforms,
            en=etof_ref.calib_data.index.tolist(),
        )

        waveform_str_source.data.update(
            xs=len(etof_str.calib_data) *
            [list(range(etof_str.internal_time_bins))],
            ys=str_waveforms,
            en=etof_str.calib_data.index.tolist(),
        )

        photon_peak_ref_span.location = etof_ref.calib_t0
        photon_peak_str_span.location = etof_str.calib_t0

        def _show_fit(calib_results, circle_source, line_source):
            # Measured points plus the fitted curve sampled at 100 points
            # between the smallest and largest non-NaN x value
            (calib_a, calib_b), x, y = calib_results
            time_fit = np.linspace(np.nanmin(x), np.nanmax(x), 100)
            en_fit = (calib_a / time_fit)**2 + calib_b
            circle_source.data.update(x=x, y=y)
            line_source.data.update(x=time_fit, y=en_fit)

        _show_fit(calib_res["0"], fit_ref_circle_source, fit_ref_line_source)
        _show_fit(calib_res["1"], fit_str_circle_source, fit_str_line_source)

        calib_const_div.text = f"""
        a_str = {etof_str.calib_a:.2f}<br>
        b_str = {etof_str.calib_b:.2f}<br>
        <br>
        a_ref = {etof_ref.calib_a:.2f}<br>
        b_ref = {etof_ref.calib_b:.2f}
        """

    calibrate_button = Button(label="Calibrate eTOF",
                              button_type="default",
                              width=250)
    calibrate_button.on_click(calibrate_button_callback)

    # Photon peak noise threshold spinner
    def phot_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        if not new_value > 0:
            # Only positive thresholds are accepted: restore the previous value
            phot_peak_noise_thr_spinner.value = old_value
            return

        for etof in palm.etofs.values():
            etof.photon_peak_noise_thr = new_value

    phot_peak_noise_thr_spinner = Spinner(title="Photon peak noise threshold:",
                                          value=1,
                                          step=0.1)
    phot_peak_noise_thr_spinner.on_change(
        "value", phot_peak_noise_thr_spinner_callback)

    # Electron peak noise threshold spinner
    def el_peak_noise_thr_spinner_callback(_attr, old_value, new_value):
        if not new_value > 0:
            # Only positive thresholds are accepted: restore the previous value
            el_peak_noise_thr_spinner.value = old_value
            return

        for etof in palm.etofs.values():
            etof.electron_peak_noise_thr = new_value

    el_peak_noise_thr_spinner = Spinner(title="Electron peak noise threshold:",
                                        value=10,
                                        step=0.1)
    el_peak_noise_thr_spinner.on_change("value",
                                        el_peak_noise_thr_spinner_callback)

    # Save calibration button
    def save_button_callback():
        palm.save_etof_calib(path=path_textinput.value)
        # Make the freshly saved file show up in the load menu immediately
        update_load_dropdown_menu()

    save_button = Button(label="Save", button_type="default", width=250)
    save_button.on_click(save_button_callback)

    # Load calibration button
    def load_dropdown_callback(_attr, _old_value, new_value):
        # The empty value is written by the reset below; nothing to load then
        if not new_value:
            return

        palm.load_etof_calib(os.path.join(path_textinput.value, new_value))
        _refresh_datatables()

        # Clear the selection so that picking the same menu item again
        # triggers this callback once more
        load_dropdown.value = ""

    def update_load_dropdown_menu():
        calib_file_ext = ".palm_etof"
        calib_dir = path_textinput.value
        if not os.path.isdir(calib_dir):
            # Signal an unusable path and offer nothing to load
            load_dropdown.button_type = "danger"
            load_dropdown.menu = []
            return

        # Menu label is the file name without the extension
        calib_files = [(entry.name[:-len(calib_file_ext)], entry.name)
                       for entry in os.scandir(calib_dir)
                       if entry.is_file()
                       and entry.name.endswith(calib_file_ext)]
        load_dropdown.button_type = "default"
        load_dropdown.menu = sorted(calib_files, reverse=True)

    doc.add_next_tick_callback(update_load_dropdown_menu)
    doc.add_periodic_callback(update_load_dropdown_menu, 5000)

    load_dropdown = Dropdown(label="Load", menu=[], width=250)
    load_dropdown.on_change("value", load_dropdown_callback)

    # eTOF fitting equation
    fit_eq_div = Div(
        text="""Fitting equation:<br><br><img src="/palm/static/5euwuy.gif">"""
    )

    # Calibration constants
    calib_const_div = Div(text=f"""
        a_str = {0}<br>
        b_str = {0}<br>
        <br>
        a_ref = {0}<br>
        b_ref = {0}
        """)

    # assemble
    plots_column = column(waveform_plot, fit_plot)
    controls_column = column(
        path_textinput,
        scans_dropdown,
        calibrate_button,
        phot_peak_noise_thr_spinner,
        el_peak_noise_thr_spinner,
        row(save_button, load_dropdown),
        row(datatable_ref, datatable_str),
        calib_const_div,
        fit_eq_div,
    )
    tab_layout = column(row(plots_column, Spacer(width=30), controls_column))

    return Panel(child=tab_layout, title="eTOF Calibration")
Exemplo n.º 24
0
 def create_standoff_spinner(self) -> Spinner:
     """Return a "Standoff" spinner preset to ``self.standoff``.

     Value changes of the spinner are forwarded to
     ``self.handle_standoff_change``.
     """
     standoff_spinner = Spinner(
         value=self.standoff,
         title="Standoff",
         width=210,
     )
     standoff_spinner.on_change("value", self.handle_standoff_change)
     return standoff_spinner
Exemplo n.º 25
0
def create():
    """Build the "param study" tab.

    All widgets, data sources and callbacks are created as locals and
    closures of this function; the assembled layout is returned as a
    ``Panel`` titled "param study".
    """
    # scans currently loaded; rebound or extended by the open/append/upload
    # callbacks defined below
    det_data = []
    # fit-function tag -> parameter table (param/value/vary/min/max columns)
    fit_params = {}
    # export file contents and names handed to the JS download callback
    js_data = ColumnDataSource(data=dict(content=["", ""], fname=["", ""]))

    def proposal_textinput_callback(_attr, _old, new):
        """List the data files of the entered proposal in ``file_select``.

        The proposal folder is looked up under every root listed in
        ``pyzebra.ZEBRA_PROPOSALS_PATHS``; the first existing directory is
        used. On success the "Open New" and "Append" buttons are enabled.

        Raises:
            ValueError: if none of the roots contains the proposal folder.
        """
        proposal = new.strip()
        candidates = (os.path.join(root, proposal)
                      for root in pyzebra.ZEBRA_PROPOSALS_PATHS)
        proposal_path = next(
            (path for path in candidates if os.path.isdir(path)), None)
        if proposal_path is None:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        # each option is a (full path, bare file name) pair
        file_select.options = [
            (os.path.join(proposal_path, name), name)
            for name in os.listdir(proposal_path)
            if name.endswith((".ccl", ".dat"))
        ]
        file_open_button.disabled = False
        file_append_button.disabled = False

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def _init_datatable():
        """Refill the scan table from ``det_data`` and reset dependent widgets.

        Every scan gets an empty parameter, a fit flag of 0 and export
        enabled. Row 0 is then selected and the scan-motor and parameter
        selectors are reset from the first scan.
        """
        scan_ids = [scan["idx"] for scan in det_data]
        file_names = [
            os.path.basename(scan["original_filename"]) for scan in det_data
        ]
        n_scans = len(scan_ids)

        scan_table_source.data.update(
            file=file_names,
            scan=scan_ids,
            param=[None] * n_scans,
            fit=[0] * n_scans,
            export=[True] * n_scans,
        )
        # NOTE(review): the selection is emptied before row 0 is selected,
        # presumably so that a change is reported even if row 0 was already
        # selected -- confirm
        scan_table_source.selected.indices = []
        scan_table_source.selected.indices = [0]

        first_scan = det_data[0]
        scan_motor_select.options = first_scan["scan_motors"]
        scan_motor_select.value = first_scan["scan_motor"]
        param_select.value = "user defined"

    file_select = MultiSelect(
        title="Available .ccl/.dat files:",
        width=210,
        height=250,
    )

    def file_open_button_callback():
        nonlocal det_data
        det_data = []
        for f_name in file_select.value:
            with open(f_name) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    file_open_button = Button(label="Open New", width=100, disabled=True)
    file_open_button.on_click(file_open_button_callback)

    def file_append_button_callback():
        """Add the files picked in ``file_select`` to the loaded scans.

        Each file is parsed, normalized with the current monitor value and
        appended to ``det_data``; the scan table is rebuilt afterwards.
        """
        for path in file_select.value:
            with open(path) as file:
                suffix = os.path.splitext(path)[1]
                extra = pyzebra.parse_1D(file, suffix)

            pyzebra.normalize_dataset(extra, monitor_spinner.value)
            det_data.extend(extra)

        _init_datatable()

    file_append_button = Button(label="Append", width=100, disabled=True)
    file_append_button.on_click(file_append_button_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal det_data
        det_data = []
        for f_str, f_name in zip(new, upload_button.filename):
            with io.StringIO(base64.b64decode(f_str).decode()) as file:
                base, ext = os.path.splitext(f_name)
                if det_data:
                    append_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(append_data,
                                              monitor_spinner.value)
                    det_data.extend(append_data)
                else:
                    det_data = pyzebra.parse_1D(file, ext)
                    pyzebra.normalize_dataset(det_data, monitor_spinner.value)
                    js_data.data.update(
                        fname=[base + ".comm", base + ".incomm"])

        _init_datatable()
        append_upload_button.disabled = False

    upload_div = Div(text="or upload new .ccl/.dat files:",
                     margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".ccl,.dat", multiple=True, width=200)
    upload_button.on_change("value", upload_button_callback)

    def append_upload_button_callback(_attr, _old, new):
        """Add files uploaded from the browser to the loaded scans.

        ``new`` holds the base64-encoded file contents, paired by position
        with the names in ``append_upload_button.filename``.
        """
        for encoded, name in zip(new, append_upload_button.filename):
            with io.StringIO(base64.b64decode(encoded).decode()) as file:
                suffix = os.path.splitext(name)[1]
                extra = pyzebra.parse_1D(file, suffix)

            pyzebra.normalize_dataset(extra, monitor_spinner.value)
            det_data.extend(extra)

        _init_datatable()

    append_upload_div = Div(text="append extra files:", margin=(5, 5, 0, 5))
    # stays disabled until a dataset was opened or uploaded
    append_upload_button = FileInput(
        accept=".ccl,.dat",
        multiple=True,
        width=200,
        disabled=True,
    )
    append_upload_button.on_change("value", append_upload_button_callback)

    def monitor_spinner_callback(_attr, _old, new):
        """Re-normalize the loaded scans with the new monitor value and redraw.

        Does nothing while no scans are loaded.
        """
        if not det_data:
            return

        pyzebra.normalize_dataset(det_data, new)
        _update_plot()

    monitor_spinner = Spinner(
        title="Monitor:",
        mode="int",
        value=100_000,
        low=1,
        width=145,
    )
    monitor_spinner.on_change("value", monitor_spinner_callback)

    def scan_motor_select_callback(_attr, _old, new):
        """Switch every loaded scan to the chosen scan motor and redraw.

        Does nothing while no scans are loaded.
        """
        if not det_data:
            return

        for scan in det_data:
            scan["scan_motor"] = new
        _update_plot()

    scan_motor_select = Select(title="Scan motor:", options=[], width=145)
    scan_motor_select.on_change("value", scan_motor_select_callback)

    def _update_table():
        """Refresh the table's "fit" column: 1 for fitted scans, else 0."""
        fit_flags = [int("fit" in scan) for scan in det_data]
        scan_table_source.data.update(fit=fit_flags)

    def _update_plot():
        """Redraw the single-scan plot for the selected scan, then the overviews."""
        selected_scan = _get_selected_scan()
        _update_single_scan_plot(selected_scan)
        _update_overview()

    def _update_single_scan_plot(scan):
        """Show one scan and, when it has been fitted, the fit curves.

        The scatter source receives the counts against the scan-motor
        positions plus ``counts +/- sqrt(counts)`` bounds for the whiskers.
        With a fit present, the total fit, the linear background and every
        peak component are evaluated on 100 points between the first and the
        last position, and the fit report is shown. Without a fit, all
        fit-related sources and the report are cleared.
        """
        motor = scan["scan_motor"]

        counts = scan["counts"]
        positions = scan[motor]

        plot.axis[0].axis_label = motor
        spread = np.sqrt(counts)
        plot_scatter_source.data.update(
            x=positions,
            y=counts,
            y_upper=counts + spread,
            y_lower=counts - spread,
        )

        fit = scan.get("fit")
        if fit is None:
            plot_fit_source.data.update(x=[], y=[])
            plot_bkg_source.data.update(x=[], y=[])
            plot_peak_source.data.update(xs=[], ys=[])
            fit_output_textinput.value = ""
            return

        x_fit = np.linspace(positions[0], positions[-1], 100)
        plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit))

        components = fit.eval_components(x=x_fit)
        bkg_x, bkg_y = [], []
        peaks_x, peaks_y = [], []
        # components are looked up as "f<i>_", i being the position of the
        # fit function in ``fit_params``
        for idx, tag in enumerate(fit_params):
            component_key = f"f{idx}_"
            if "linear" in tag:
                bkg_x = x_fit
                bkg_y = components[component_key]
            elif any(kind in tag for kind in ("gaussian", "voigt", "pvoigt")):
                peaks_x.append(x_fit)
                peaks_y.append(components[component_key])

        plot_bkg_source.data.update(x=bkg_x, y=bkg_y)
        plot_peak_source.data.update(xs=peaks_x, ys=peaks_y)

        fit_output_textinput.value = fit.fit_report()

    def _update_overview():
        """Refresh the "overview" multiline plot and the "overview map".

        Only scans whose table "param" entry is not None contribute. The
        multiline source gets one line per such scan; the map gets every
        measured point as a scatter marker plus an image interpolated from
        those points.
        """
        # per-scan series for the multiline plot
        xs = []
        ys = []
        param = []
        # flattened points for the map: x = motor position, y = parameter
        # value repeated for each point, par = counts (despite its name)
        x = []
        y = []
        par = []
        for s, p in enumerate(scan_table_source.data["param"]):
            if p is not None:
                scan = det_data[s]
                scan_motor = scan["scan_motor"]
                xs.append(scan[scan_motor])
                x.extend(scan[scan_motor])
                ys.append(scan["counts"])
                y.extend([float(p)] * len(scan[scan_motor]))
                param.append(float(p))
                par.extend(scan["counts"])

        if det_data:
            # both overview plots are labelled with the motor of the first scan
            scan_motor = det_data[0]["scan_motor"]
            ov_plot.axis[0].axis_label = scan_motor
            ov_param_plot.axis[0].axis_label = scan_motor

        ov_plot_mline_source.data.update(xs=xs,
                                         ys=ys,
                                         param=param,
                                         color=color_palette(len(xs)))

        if y:
            # colour range spans the counts of all contributing scans; the
            # ``y`` inside the comprehensions is local to them and does not
            # touch the outer list
            mapper["transform"].low = np.min([np.min(y) for y in ys])
            mapper["transform"].high = np.max([np.max(y) for y in ys])
        ov_param_plot_scatter_source.data.update(x=x, y=y, param=par)

        if y:
            # NOTE(review): scipy.interpolate.interp2d is deprecated since
            # SciPy 1.10 and removed in 1.14 -- confirm the supported SciPy
            # version
            interp_f = interpolate.interp2d(x, y, par)
            x1, x2 = min(x), max(x)
            y1, y2 = min(y), max(y)
            # image grid has one tenth of the plot's inner size per axis
            image = interp_f(
                np.linspace(x1, x2, ov_param_plot.inner_width // 10),
                np.linspace(y1, y2, ov_param_plot.inner_height // 10),
                assume_sorted=True,
            )
            ov_param_plot_image_source.data.update(image=[image],
                                                   x=[x1],
                                                   y=[y1],
                                                   dw=[x2 - x1],
                                                   dh=[y2 - y1])
        else:
            # nothing to show: clear the image
            ov_param_plot_image_source.data.update(image=[],
                                                   x=[],
                                                   y=[],
                                                   dw=[],
                                                   dh=[])

    def _update_param_plot():
        """Plot the selected fit parameter of every fitted scan.

        The x value is the scan's table "param" entry, the y value the fitted
        value of the parameter chosen in ``fit_param_select``. The plot is
        emptied when no fit parameter is selected.
        """
        selected_param = fit_param_select.value
        points = [
            (table_param, scan["fit"].values[selected_param])
            for scan, table_param in zip(det_data,
                                         scan_table_source.data["param"])
            if "fit" in scan and selected_param
        ]

        param_plot_scatter_source.data.update(
            x=[px for px, _ in points],
            y=[py for _, py in points],
        )

    # Main plot: the selected scan with its fit, background and peak curves
    plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(only_visible=True),
        plot_height=450,
        plot_width=700,
    )

    plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    for grid_dimension in (0, 1):
        plot.add_layout(Grid(dimension=grid_dimension, ticker=BasicTicker()))

    # measured points; the whisker annotation reads its bounds from the
    # same source
    plot_scatter_source = ColumnDataSource(
        dict(x=[0], y=[0], y_upper=[0], y_lower=[0]))
    plot_scatter = plot.add_glyph(
        plot_scatter_source,
        Scatter(x="x", y="y", line_color="steelblue"),
    )
    plot.add_layout(
        Whisker(
            source=plot_scatter_source,
            base="x",
            upper="y_upper",
            lower="y_lower",
        ))

    # total fit curve
    plot_fit_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y"))

    # linear background component
    plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0]))
    plot_bkg = plot.add_glyph(
        plot_bkg_source,
        Line(x="x", y="y", line_color="green", line_dash="dashed"),
    )

    # peak components, one line per peak function
    plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]]))
    plot_peak = plot.add_glyph(
        plot_peak_source,
        MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed"),
    )

    # dashed markers moved by the "Fit from:" / "to:" spinners
    fit_from_span = Span(location=None, dimension="height", line_dash="dashed")
    fit_to_span = Span(location=None, dimension="height", line_dash="dashed")
    plot.add_layout(fit_from_span)
    plot.add_layout(fit_to_span)

    legend_entries = [
        ("data", [plot_scatter]),
        ("best fit", [plot_fit]),
        ("peak", [plot_peak]),
        ("linear", [plot_bkg]),
    ]
    plot.add_layout(
        Legend(
            items=legend_entries,
            location="top_left",
            click_policy="hide",
        ))

    plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    plot.toolbar.logo = None

    # Overview multilines plot: one line per scan that has a parameter value
    ov_plot = Plot(x_range=DataRange1d(),
                   y_range=DataRange1d(),
                   plot_height=450,
                   plot_width=700)

    ov_plot.add_layout(LinearAxis(axis_label="Counts"), place="left")
    ov_plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below")

    ov_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    ov_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    ov_plot_mline_source = ColumnDataSource(
        dict(xs=[], ys=[], param=[], color=[]))
    ov_plot.add_glyph(ov_plot_mline_source,
                      MultiLine(xs="xs", ys="ys", line_color="color"))

    # the hover tooltip shows the parameter value of the line under the cursor
    hover_tool = HoverTool(tooltips=[("param", "@param")])
    # Register every tool exactly once: a second add_tools() call with fresh
    # Pan/WheelZoom/Reset instances would add duplicate tools to the toolbar.
    ov_plot.add_tools(PanTool(), WheelZoomTool(), hover_tool, ResetTool())
    ov_plot.toolbar.logo = None

    # Overview params plot: interpolated image with the measured points on top
    ov_param_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=450,
        plot_width=700,
    )

    ov_param_plot.add_layout(LinearAxis(axis_label="Param"), place="left")
    ov_param_plot.add_layout(
        LinearAxis(axis_label="Scan motor"), place="below")

    for grid_dimension in (0, 1):
        ov_param_plot.add_layout(
            Grid(dimension=grid_dimension, ticker=BasicTicker()))

    ov_param_plot_image_source = ColumnDataSource(
        dict(image=[], x=[], y=[], dw=[], dh=[]))
    ov_param_plot.add_glyph(
        ov_param_plot_image_source,
        Image(image="image", x="x", y="y", dw="dw", dh="dh"),
    )

    ov_param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[], param=[]))
    # markers are coloured by the "param" column; the low/high limits set
    # here are only the initial values of the colour mapper
    mapper = linear_cmap(field_name="param", palette=Turbo256, low=0, high=50)
    ov_param_plot.add_glyph(
        ov_param_plot_scatter_source,
        Scatter(x="x", y="y", line_color=mapper, fill_color=mapper, size=10),
    )

    ov_param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    ov_param_plot.toolbar.logo = None

    # Parameter plot: a fitted parameter against the per-scan table parameter
    param_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=400,
        plot_width=700,
    )

    param_plot.add_layout(LinearAxis(axis_label="Fit parameter"), place="left")
    param_plot.add_layout(LinearAxis(axis_label="Parameter"), place="below")

    for grid_dimension in (0, 1):
        param_plot.add_layout(
            Grid(dimension=grid_dimension, ticker=BasicTicker()))

    param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[]))
    param_plot.add_glyph(param_plot_scatter_source, Scatter(x="x", y="y"))

    param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
    param_plot.toolbar.logo = None

    def fit_param_select_callback(_attr, _old, _new):
        """Redraw the parameter plot after another fit parameter was chosen."""
        _update_param_plot()

    fit_param_select = Select(title="Fit parameter", options=[], width=145)
    fit_param_select.on_change("value", fit_param_select_callback)

    # Plot tabs
    plot_panels = [
        Panel(child=plot, title="single scan"),
        Panel(child=ov_plot, title="overview"),
        Panel(child=ov_param_plot, title="overview map"),
        Panel(
            child=column(param_plot, row(fit_param_select)),
            title="parameter plot",
        ),
    ]
    plots = Tabs(tabs=plot_panels)

    # Scan select
    def scan_table_select_callback(_attr, old, new):
        """Keep the table selection to a single row and redraw on change.

        Empty selections are ignored. A multi-row selection (Shift+Click or
        Ctrl+Click) is reverted to the previous one; the follow-up event
        caused by that revert is ignored as well.
        """
        if not new:
            return

        if len(new) > 1:
            # multiple indices are not supported: restore the previous ones
            scan_table_source.selected.indices = old
            return

        if len(old) > 1:
            # this event is only the echo of the revert above
            return

        _update_plot()

    def scan_table_preview_callback(_attr, _old, _new):
        """Rebuild the export preview after the table data changed."""
        _update_preview()

    scan_table_source = ColumnDataSource(
        dict(file=[], scan=[], param=[], fit=[], export=[]))
    scan_table_source.on_change("data", scan_table_preview_callback)

    scan_table_columns = [
        TableColumn(field="file", title="file", width=150),
        TableColumn(field="scan", title="scan", width=50),
        TableColumn(
            field="param",
            title="param",
            editor=NumberEditor(),
            width=50,
        ),
        TableColumn(field="fit", title="Fit", width=50),
        TableColumn(
            field="export",
            title="Export",
            editor=CheckboxEditor(),
            width=50,
        ),
    ]
    scan_table = DataTable(
        source=scan_table_source,
        columns=scan_table_columns,
        width=410,  # +60 because of the index column
        editable=True,
        autosize_mode="none",
    )

    def scan_table_source_callback(_attr, _old, _new):
        """Redraw the plots after the table data changed, if a row is selected."""
        if scan_table_source.selected.indices:
            _update_plot()

    scan_table_source.selected.on_change("indices", scan_table_select_callback)
    scan_table_source.on_change("data", scan_table_source_callback)

    def _get_selected_scan():
        """Return the loaded scan that belongs to the first selected table row.

        Raises:
            IndexError: if no row is selected or the index is not in
                ``det_data``.
        """
        first_selected = scan_table_source.selected.indices[0]
        return det_data[first_selected]

    def param_select_callback(_attr, _old, new):
        """Fill the table's "param" column according to the chosen source.

        "user defined" blanks the column (all None); any other choice copies
        the entry of that name from every loaded scan. The parameter plot is
        redrawn afterwards.
        """
        user_defined = new == "user defined"
        scan_table_source.data["param"] = [
            None if user_defined else scan[new] for scan in det_data
        ]
        _update_param_plot()

    param_select = Select(
        title="Parameter:",
        options=["user defined", "temp", "mf", "h", "k", "l"],
        value="user defined",
        width=145,
    )
    param_select.on_change("value", param_select_callback)

    def fit_from_spinner_callback(_attr, _old, new):
        """Move the "fit from" marker of the main plot to the new value."""
        fit_from_span.location = new

    def fit_to_spinner_callback(_attr, _old, new):
        """Move the "fit to" marker of the main plot to the new value."""
        fit_to_span.location = new

    fit_from_spinner = Spinner(title="Fit from:", width=145)
    fit_from_spinner.on_change("value", fit_from_spinner_callback)

    fit_to_spinner = Spinner(title="to:", width=145)
    fit_to_spinner.on_change("value", fit_to_spinner_callback)

    def fitparams_add_dropdown_callback(click):
        """Add the clicked fit function to the list of functions to fit.

        The function gets a unique tag "<name>-<counter>"; the running
        counter is kept in ``fitparams_select.tags[0]``. A fresh parameter
        table for the function is stored in ``fit_params`` under that tag.
        """
        function = click.item
        counter = fitparams_select.tags[0]

        # bokeh requires (str, str) for MultiSelect options
        new_tag = f"{function}-{counter}"
        fitparams_select.options.append((new_tag, function))
        fit_params[new_tag] = fitparams_factory(function)
        fitparams_select.tags[0] = counter + 1

    fit_function_menu = [
        ("Linear", "linear"),
        ("Gaussian", "gaussian"),
        ("Voigt", "voigt"),
        ("Pseudo Voigt", "pvoigt"),
        # ("Pseudo Voigt1", "pseudovoigt1"),
    ]
    fitparams_add_dropdown = Dropdown(
        label="Add fit function",
        menu=fit_function_menu,
        width=145,
    )
    fitparams_add_dropdown.on_click(fitparams_add_dropdown_callback)

    def fitparams_select_callback(_attr, old, new):
        """Show the parameters of the selected fit function in the table.

        A multi-item selection (Shift+Click or Ctrl+Click) is reverted to the
        previous one, and the follow-up event caused by that revert is
        ignored. An empty selection clears the table.
        """
        if len(new) > 1:
            # multiple indices are not supported: restore the previous ones
            fitparams_select.value = old
            return

        if len(old) > 1:
            # this event is only the echo of the revert above
            return

        table_data = (fit_params[new[0]] if new else dict(
            param=[], value=[], vary=[], min=[], max=[]))
        fitparams_table_source.data.update(table_data)

    fitparams_select = MultiSelect(options=[], height=120, width=145)
    # tags[0] is the running counter used to build unique function tags
    fitparams_select.tags = [0]
    fitparams_select.on_change("value", fitparams_select_callback)

    def fitparams_remove_button_callback():
        """Remove the selected fit function and clear the selection.

        Does nothing when no function is selected.
        """
        selection = fitparams_select.value
        if not selection:
            return

        tag = selection[0]
        del fit_params[tag]

        matching_option = next(
            (opt for opt in fitparams_select.options if opt[0] == tag), None)
        if matching_option is not None:
            fitparams_select.options.remove(matching_option)

        fitparams_select.value = []

    fitparams_remove_button = Button(label="Remove fit function", width=145)
    fitparams_remove_button.on_click(fitparams_remove_button_callback)

    def fitparams_factory(function):
        """Return a fresh default parameter table for a fit function.

        Args:
            function: one of "linear", "gaussian", "voigt", "pvoigt" or
                "pseudovoigt1".

        Returns:
            A dict with the columns "param", "value", "vary", "min" and
            "max", each a list with one entry per parameter. Values and
            bounds default to None and every parameter varies, except for the
            per-function overrides applied below.

        Raises:
            ValueError: if ``function`` is not one of the known names.
        """
        # built on every call so that each table owns its own lists
        known_functions = (
            ("linear", ["slope", "intercept"]),
            ("gaussian", ["amplitude", "center", "sigma"]),
            ("voigt", ["amplitude", "center", "sigma", "gamma"]),
            ("pvoigt", ["amplitude", "center", "sigma", "fraction"]),
            ("pseudovoigt1",
             ["amplitude", "center", "g_sigma", "l_sigma", "fraction"]),
        )
        for name, names in known_functions:
            if function == name:
                matched, params = name, names
                break
        else:
            raise ValueError("Unknown fit function")

        count = len(params)
        fitparams = {
            "param": params,
            "value": [None] * count,
            "vary": [True] * count,
            "min": [None] * count,
            "max": [None] * count,
        }

        if matched == "linear":
            fitparams.update(
                value=[0, 1],
                vary=[False, True],
                min=[None, 0],
            )
        elif matched == "gaussian":
            fitparams.update(min=[0, None, None])

        return fitparams

    # editable table with the parameters of the selected fit function
    fitparams_table_source = ColumnDataSource(
        dict(param=[], value=[], vary=[], min=[], max=[]))
    fitparams_table_columns = [
        TableColumn(field="param", title="Parameter"),
        TableColumn(field="value", title="Value", editor=NumberEditor()),
        TableColumn(field="vary", title="Vary", editor=CheckboxEditor()),
        TableColumn(field="min", title="Min", editor=NumberEditor()),
        TableColumn(field="max", title="Max", editor=NumberEditor()),
    ]
    fitparams_table = DataTable(
        source=fitparams_table_source,
        columns=fitparams_table_columns,
        height=200,
        width=350,
        index_position=None,
        editable=True,
        auto_edit=True,
    )

    # start with `background` and `gauss` fit functions added
    for initial_function in ("linear", "gaussian"):
        fitparams_add_dropdown_callback(
            types.SimpleNamespace(item=initial_function))
    fitparams_select.value = ["gaussian-1"]  # add selection to gauss

    fit_output_textinput = TextAreaInput(
        title="Fit results:",
        width=750,
        height=200,
    )

    def proc_all_button_callback():
        """Fit every scan marked for export and refresh plots and table.

        Each such scan is passed to ``pyzebra.fit_scan`` with the current fit
        functions and fit range, then to ``pyzebra.get_area`` with the chosen
        intensity method and Lorentz setting. Afterwards the fit-parameter
        selector is filled from the first scan that carries a fit.
        """
        export_flags = scan_table_source.data["export"]
        for scan, export in zip(det_data, export_flags):
            if not export:
                continue

            pyzebra.fit_scan(
                scan,
                fit_params,
                fit_from=fit_from_spinner.value,
                fit_to=fit_to_spinner.value,
            )
            pyzebra.get_area(
                scan,
                area_method=AREA_METHODS[area_method_radiobutton.active],
                lorentz=lorentz_checkbox.active,
            )

        _update_plot()
        _update_table()

        first_fitted = next((s for s in det_data if "fit" in s), None)
        if first_fitted is not None:
            param_names = list(first_fitted["fit"].params.keys())
            fit_param_select.options = param_names
            fit_param_select.value = param_names[0]
        _update_param_plot()

    proc_all_button = Button(
        label="Process All",
        button_type="primary",
        width=145,
    )
    proc_all_button.on_click(proc_all_button_callback)

    def proc_button_callback():
        """Fit the currently selected scan and refresh plots and table.

        The scan is passed to ``pyzebra.fit_scan`` with the current fit
        functions and fit range, then to ``pyzebra.get_area`` with the chosen
        intensity method and Lorentz setting. Afterwards the fit-parameter
        selector is filled from the first scan that carries a fit.
        """
        current_scan = _get_selected_scan()
        pyzebra.fit_scan(
            current_scan,
            fit_params,
            fit_from=fit_from_spinner.value,
            fit_to=fit_to_spinner.value,
        )
        pyzebra.get_area(
            current_scan,
            area_method=AREA_METHODS[area_method_radiobutton.active],
            lorentz=lorentz_checkbox.active,
        )

        _update_plot()
        _update_table()

        first_fitted = next((s for s in det_data if "fit" in s), None)
        if first_fitted is not None:
            param_names = list(first_fitted["fit"].params.keys())
            fit_param_select.options = param_names
            fit_param_select.value = param_names[0]
        _update_param_plot()

    proc_button = Button(label="Process Current", width=145)
    proc_button.on_click(proc_button_callback)

    # intensity method; the active index is used to pick from AREA_METHODS
    area_method_div = Div(text="Intensity:", margin=(5, 5, 0, 5))
    area_method_radiobutton = RadioGroup(
        labels=["Function", "Area"],
        active=0,
        width=145,
    )

    lorentz_checkbox = CheckboxGroup(
        labels=["Lorentz Correction"],
        width=145,
        margin=(13, 5, 5, 5),
    )

    export_preview_textinput = TextAreaInput(
        title="Export file preview:",
        width=450,
        height=400,
    )

    def _update_preview():
        """Rebuild the export preview text and the download payload.

        The ".comm" and ".incomm" files are looked for under a temporary
        base name. Their contents (an empty string for a missing file) go to
        ``js_data`` for the download button; the contents of the existing
        files, each prefixed with a header line, form the preview text.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = temp_dir + "/temp"
            export_data = [
                scan for scan, export in zip(det_data,
                                             scan_table_source.data["export"])
                if export
            ]

            # NOTE(review): with the export call commented out nothing is
            # written, so neither file is found below -- confirm whether this
            # is intentional
            # pyzebra.export_1D(export_data, temp_file, "fullprof")

            preview_parts = []
            file_content = []
            for ext in (".comm", ".incomm"):
                fname = temp_file + ext
                content = ""
                if os.path.isfile(fname):
                    with open(fname) as f:
                        content = f.read()
                    preview_parts.append(f"{ext} file:\n" + content)
                file_content.append(content)

            js_data.data.update(content=file_content)
            export_preview_textinput.value = "".join(preview_parts)

    save_button = Button(
        label="Download File(s)",
        button_type="success",
        width=220,
    )
    save_button.js_on_click(
        CustomJS(args={"js_data": js_data}, code=javaScript))

    fitpeak_controls = row(
        column(fitparams_add_dropdown, fitparams_select,
               fitparams_remove_button),
        fitparams_table,
        Spacer(width=20),
        column(fit_from_spinner, lorentz_checkbox, area_method_div,
               area_method_radiobutton),
        column(fit_to_spinner, proc_button, proc_all_button),
    )

    scan_layout = column(scan_table,
                         row(monitor_spinner, scan_motor_select, param_select))

    import_layout = column(
        proposal_textinput,
        file_select,
        row(file_open_button, file_append_button),
        upload_div,
        upload_button,
        append_upload_div,
        append_upload_button,
    )

    export_layout = column(export_preview_textinput, row(save_button))

    tab_layout = column(
        row(import_layout, scan_layout, plots, Spacer(width=30),
            export_layout),
        row(fitpeak_controls, fit_output_textinput),
    )

    return Panel(child=tab_layout, title="param study")
Exemplo n.º 26
0
def create(palm):
    """Create the "HDF5 File" panel for browsing processed saved runs.

    The panel contains four line plots (eTOF waveforms, their
    cross-correlation, pulse delays and pulse lengths) together with the
    controls used to pick a folder and a saved run, select a pulse, adjust
    the energy interpolation range and save the results as a CSV file.

    Args:
        palm: Object providing an ``energy_range`` array (read through
            ``.min()``, ``.max()`` and ``.size``, and reassigned when the
            energy spinners change) and a
            ``process_hdf5_file(filepath, debug=True)`` method returning
            ``(tags, delays, lengths, debug_data)``.

    Returns:
        Panel: A panel titled "HDF5 File" wrapping the assembled layout.
    """
    energy_min = palm.energy_range.min()
    energy_max = palm.energy_range.max()
    energy_npoints = palm.energy_range.size

    # (file name, tags, delays, lengths) of the last processed run; the falsy
    # first element marks "nothing processed yet" for the save callback.
    current_results = (0, 0, 0, 0)

    doc = curdoc()

    def _line_plot(title, x_axis_label, y_axis_label):
        """Return an empty plot with the common tools, axes and grid lines."""
        plot = Plot(
            title=Title(text=title),
            x_range=DataRange1d(),
            y_range=DataRange1d(),
            plot_height=PLOT_CANVAS_HEIGHT,
            plot_width=PLOT_CANVAS_WIDTH,
            toolbar_location="right",
        )

        # ---- tools
        plot.toolbar.logo = None
        plot.add_tools(PanTool(), BoxZoomTool(), WheelZoomTool(), ResetTool())

        # ---- axes
        plot.add_layout(LinearAxis(axis_label=x_axis_label), place="below")
        plot.add_layout(
            LinearAxis(axis_label=y_axis_label,
                       major_label_orientation="vertical"),
            place="left",
        )

        # ---- grid lines
        plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
        plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

        return plot

    # Streaked and reference waveforms plot
    waveform_plot = _line_plot("eTOF waveforms", "Photon energy, eV",
                               "Intensity")

    # ---- line glyphs
    waveform_source = ColumnDataSource(
        dict(x_str=[], y_str=[], x_ref=[], y_ref=[]))
    waveform_ref_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_ref", y="y_ref", line_color="blue"))
    waveform_str_line = waveform_plot.add_glyph(
        waveform_source, Line(x="x_str", y="y_str", line_color="red"))

    # ---- legend
    waveform_plot.add_layout(
        Legend(items=[
            ("reference", [waveform_ref_line]),
            ("streaked", [waveform_str_line]),
        ]))
    waveform_plot.legend.click_policy = "hide"

    # Cross-correlation plot
    xcorr_plot = _line_plot("Waveforms cross-correlation", "Energy shift, eV",
                            "Cross-correlation")

    # ---- line glyphs
    xcorr_source = ColumnDataSource(dict(lags=[], xcorr1=[], xcorr2=[]))
    xcorr_plot.add_glyph(
        xcorr_source,
        Line(x="lags", y="xcorr1", line_color="purple", line_dash="dashed"))
    xcorr_plot.add_glyph(xcorr_source,
                         Line(x="lags", y="xcorr2", line_color="purple"))

    # ---- vertical span
    xcorr_center_span = Span(location=0, dimension="height")
    xcorr_plot.add_layout(xcorr_center_span)

    # Delays plot
    pulse_delay_plot = _line_plot("Pulse delays", "Pulse number",
                                  "Pulse delay (uncalib), eV")

    # ---- line glyphs
    pulse_delay_source = ColumnDataSource(dict(pulse=[], delay=[]))
    pulse_delay_plot.add_glyph(
        pulse_delay_source, Line(x="pulse", y="delay", line_color="steelblue"))

    # ---- vertical span
    pulse_delay_plot_span = Span(location=0, dimension="height")
    pulse_delay_plot.add_layout(pulse_delay_plot_span)

    # Pulse lengths plot
    pulse_length_plot = _line_plot("Pulse lengths", "Pulse number",
                                   "Pulse length (uncalib), eV")

    # ---- line glyphs
    pulse_length_source = ColumnDataSource(dict(x=[], y=[]))
    pulse_length_plot.add_glyph(pulse_length_source,
                                Line(x="x", y="y", line_color="steelblue"))

    # ---- vertical span
    pulse_length_plot_span = Span(location=0, dimension="height")
    pulse_length_plot.add_layout(pulse_length_plot_span)

    # Folder path text input
    def path_textinput_callback(_attr, _old_value, new_value):
        # the save folder follows the browsed folder, and the run list is
        # refreshed right away instead of waiting for the periodic callback
        save_textinput.value = new_value
        path_periodic_update()

    path_textinput = TextInput(title="Folder Path:",
                               value=os.path.expanduser("~"),
                               width=510)
    path_textinput.on_change("value", path_textinput_callback)

    # Saved runs dropdown menu
    def h5_update(pulse, delays, debug_data):
        """Show the waveforms and cross-correlation of a single pulse."""
        prep_data, lags, corr_res_uncut, corr_results = debug_data

        waveform_source.data.update(
            x_str=palm.energy_range,
            y_str=prep_data["1"][pulse, :],
            x_ref=palm.energy_range,
            y_ref=prep_data["0"][pulse, :],
        )

        xcorr_source.data.update(lags=lags,
                                 xcorr1=corr_res_uncut[pulse, :],
                                 xcorr2=corr_results[pulse, :])

        xcorr_center_span.location = delays[pulse]
        pulse_delay_plot_span.location = pulse
        pulse_length_plot_span.location = pulse

    def h5_update_fun(pulse):
        """Do nothing; rebound in 'saved_runs_dropdown_callback'."""

    def saved_runs_dropdown_callback(_attr, _old_value, new_value):
        """Process the selected run and refresh all plots and controls."""
        nonlocal h5_update_fun, current_results
        if new_value == "Saved Runs":
            # the dropdown still shows its initial label: no run selected
            return

        saved_runs_dropdown.label = new_value
        filepath = os.path.join(path_textinput.value, new_value)
        tags, delays, lengths, debug_data = palm.process_hdf5_file(
            filepath, debug=True)
        current_results = (new_value, tags, delays, lengths)

        if autosave_checkbox.active:
            save_button_callback()

        pulse_delay_source.data.update(pulse=np.arange(len(delays)),
                                       delay=delays)
        pulse_length_source.data.update(x=np.arange(len(lengths)),
                                        y=lengths)
        h5_update_fun = partial(h5_update,
                                delays=delays,
                                debug_data=debug_data)

        pulse_slider.end = len(delays) - 1
        pulse_slider.value = 0
        h5_update_fun(0)

    saved_runs_dropdown = Dropdown(label="Saved Runs",
                                   button_type="primary",
                                   menu=[])
    saved_runs_dropdown.on_change("value", saved_runs_dropdown_callback)

    # ---- saved run periodic update
    def path_periodic_update():
        """Rebuild the dropdown menu from the HDF5 files in the folder."""
        folder = path_textinput.value
        new_menu = []
        if os.path.isdir(folder):
            # the context manager releases the directory handle even if
            # iteration is interrupted; this runs every few seconds
            with os.scandir(folder) as entries:
                new_menu = [(entry.name, entry.name) for entry in entries
                            if entry.is_file()
                            and entry.name.endswith((".hdf5", ".h5"))]
        saved_runs_dropdown.menu = sorted(new_menu, reverse=True)

    doc.add_periodic_callback(path_periodic_update, 5000)

    # Pulse number slider
    def pulse_slider_callback(_attr, _old_value, new_value):
        h5_update_fun(pulse=new_value)

    pulse_slider = Slider(
        start=0,
        end=99999,
        value=0,
        step=1,
        title="Pulse ID",
        callback_policy="throttle",
        callback_throttle=500,
    )
    pulse_slider.on_change("value", pulse_slider_callback)

    def _apply_energy_range():
        """Rebuild palm.energy_range and reprocess the currently shown run."""
        palm.energy_range = np.linspace(energy_min, energy_max,
                                        energy_npoints)
        # the dropdown label holds the current run name (or "Saved Runs",
        # which the callback ignores)
        saved_runs_dropdown_callback("", "", saved_runs_dropdown.label)

    # Energy maximal range value text input
    def energy_max_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_max
        # a missing value (None) is rejected the same way as an invalid one
        if new_value is not None and new_value > energy_min:
            energy_max = new_value
            _apply_energy_range()
        else:
            energy_max_spinner.value = old_value

    energy_max_spinner = Spinner(title="Maximal Energy, eV:",
                                 value=energy_max,
                                 step=0.1)
    energy_max_spinner.on_change("value", energy_max_spinner_callback)

    # Energy minimal range value text input
    def energy_min_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_min
        # a missing value (None) is rejected the same way as an invalid one
        if new_value is not None and new_value < energy_max:
            energy_min = new_value
            _apply_energy_range()
        else:
            energy_min_spinner.value = old_value

    energy_min_spinner = Spinner(title="Minimal Energy, eV:",
                                 value=energy_min,
                                 step=0.1)
    energy_min_spinner.on_change("value", energy_min_spinner_callback)

    # Energy number of interpolation points text input
    def energy_npoints_spinner_callback(_attr, old_value, new_value):
        nonlocal energy_npoints
        # np.linspace needs an integer number of samples, so the spinner
        # value is converted before it is validated and stored
        npoints = None if new_value is None else int(new_value)
        if npoints is not None and npoints > 1:
            energy_npoints = npoints
            _apply_energy_range()
        else:
            energy_npoints_spinner.value = old_value

    energy_npoints_spinner = Spinner(title="Number of interpolation points:",
                                     value=energy_npoints)
    energy_npoints_spinner.on_change("value", energy_npoints_spinner_callback)

    # Save location
    save_textinput = TextInput(title="Save Folder Path:",
                               value=os.path.expanduser("~"))

    # Autosave checkbox
    autosave_checkbox = CheckboxButtonGroup(labels=["Auto Save"],
                                            active=[],
                                            width=250)

    # Save button
    def save_button_callback():
        """Write the last processed results to '<run name>.csv'."""
        if not current_results[0]:
            # nothing has been processed yet
            return

        filename, tags, delays, lengths = current_results
        save_filename = os.path.splitext(filename)[0] + ".csv"
        df = pd.DataFrame({
            "pulse_id": tags,
            "pulse_delay": delays,
            "pulse_length": lengths
        })
        df.to_csv(os.path.join(save_textinput.value, save_filename),
                  index=False)

    save_button = Button(label="Save Results",
                         button_type="default",
                         width=250)
    save_button.on_click(save_button_callback)

    # assemble
    tab_layout = column(
        row(
            column(waveform_plot, xcorr_plot),
            Spacer(width=30),
            column(
                path_textinput,
                saved_runs_dropdown,
                pulse_slider,
                Spacer(height=30),
                energy_min_spinner,
                energy_max_spinner,
                energy_npoints_spinner,
                Spacer(height=30),
                save_textinput,
                autosave_checkbox,
                save_button,
            ),
        ),
        row(pulse_delay_plot, Spacer(width=10), pulse_length_plot),
    )

    return Panel(child=tab_layout, title="HDF5 File")
Exemplo n.º 27
0
    def __init__(self,
                 nplots,
                 plot_height=200,
                 plot_width=1000,
                 rollover=10800,
                 mode="time"):
        """Initialize stream graph plots.

        Builds ``nplots`` plots that share one x range, each with a per-frame
        line and a moving-average line, plus a "Moving Average Window" spinner
        and a "Reset" button exposed as ``self.moving_average_spinner`` and
        ``self.reset_button``.

        Args:
            nplots (int): Number of stream plots that will share common controls.
            plot_height (int, optional): Height of plot area in screen pixels. Defaults to 200.
            plot_width (int, optional): Width of plot area in screen pixels. Defaults to 1000.
            rollover (int, optional): A maximum number of points, above which data from the start
                begins to be discarded. If None, then graph will grow unbounded. Defaults to 10800.
            mode (str, optional): stream update mode, 'time' - uses the local wall time,
                'number' - uses an image number counter. Defaults to 'time'.
        """
        self.rollover = rollover
        self.mode = mode
        self._stream_t = 0
        self._buffers = []
        # moving-average window; also the initial value of the spinner below
        self._window = 30

        # Custom tick formatter for displaying large numbers
        # (one instance, shared by the left axis of every plot)
        tick_formatter = BasicTickFormatter(precision=1)

        # Stream graphs
        self.plots = []
        self.glyphs = []
        self._sources = []
        for ind in range(nplots):
            # share x_range between plots
            # (created on the first iteration only and reused afterwards)
            if ind == 0:
                x_range = DataRange1d()

            plot = Plot(
                x_range=x_range,
                y_range=DataRange1d(),
                plot_height=plot_height,
                plot_width=plot_width,
            )

            # ---- tools
            plot.toolbar.logo = None
            plot.add_tools(PanTool(), BoxZoomTool(),
                           WheelZoomTool(dimensions="width"), ResetTool())

            # ---- axes
            plot.add_layout(LinearAxis(formatter=tick_formatter), place="left")
            if mode == "time":
                plot.add_layout(DatetimeAxis(), place="below")
            elif mode == "number":
                plot.add_layout(LinearAxis(), place="below")
            else:
                # any other mode value leaves the plot without a bottom axis
                pass

            # ---- grid lines
            plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
            plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

            # ---- line glyph
            # both lines read from the same source, using separate columns
            source = ColumnDataSource(dict(x=[], y=[], x_avg=[], y_avg=[]))
            line = Line(x="x", y="y", line_color="gray")
            line_avg = Line(x="x_avg", y="y_avg", line_color="red")
            line_renderer = plot.add_glyph(source, line)
            line_avg_renderer = plot.add_glyph(source, line_avg)

            # ---- legend
            plot.add_layout(
                Legend(
                    items=[("per frame", [line_renderer]),
                           ("moving average", [line_avg_renderer])],
                    location="top_left",
                ))
            plot.legend.click_policy = "hide"

            self.plots.append(plot)
            # only the per-frame line glyph is kept, not the moving-average one
            self.glyphs.append(line)
            self._sources.append(source)
            # one bounded buffer per plot; MAXLEN is defined outside this method
            self._buffers.append(deque(maxlen=MAXLEN))

        # Moving average spinner
        def moving_average_spinner_callback(_attr, _old_value, new_value):
            # values outside the spinner's [low, high] bounds are ignored
            if moving_average_spinner.low <= new_value <= moving_average_spinner.high:
                self._window = new_value

        moving_average_spinner = Spinner(title="Moving Average Window:",
                                         value=self._window,
                                         low=1,
                                         high=MAXLEN,
                                         default_size=145)
        moving_average_spinner.on_change("value",
                                         moving_average_spinner_callback)
        self.moving_average_spinner = moving_average_spinner

        # Reset button
        def reset_button_callback():
            # keep the latest point in order to prevent a full axis reset
            if mode == "time":
                # NOTE(review): currently a no-op, self._stream_t is left as is
                pass  # update with the latest time
            elif mode == "number":
                self._stream_t = 1

            # collapse every source that already holds data to a single point
            # placed at self._stream_t, keeping the last y and y_avg values
            for source in self._sources:
                if source.data["x"]:
                    source.data.update(
                        x=[self._stream_t],
                        y=[source.data["y"][-1]],
                        x_avg=[self._stream_t],
                        y_avg=[source.data["y_avg"][-1]],
                    )

        reset_button = Button(label="Reset",
                              button_type="default",
                              default_size=145)
        reset_button.on_click(reset_button_callback)
        self.reset_button = reset_button