示例#1
0
 def _add_multiselect(self):
     """Create the state picker and hook it to the case and death plots.

     The widget starts with state '01' selected and takes its options from
     ``self.cases.options``; every selection change fires
     ``_callback_cases`` and then ``_callback_deaths``.
     """
     picker = MultiSelect(
         title='States:',
         value=['01'],
         options=self.cases.options,
     )
     self.multiselect = picker
     picker.max_width = 170
     # 500 px minus a fixed 47 px offset (purpose of the offset not visible
     # here -- presumably room for the title; confirm).
     picker.min_height = 500 - 47
     for handler in (self._callback_cases, self._callback_deaths):
         picker.on_change('value', handler)
示例#2
0
def create_multiselect(cont):
    """
    Build the continent-level multiselect box used on the "All" tab.

    Args:
        cont (str): Continent name; used as the widget title and as the key
            into ``by_cont`` for the list of selectable options.

    Returns:
        multi_select (:obj:`MultiSelect`): Box with nothing selected whose
        value changes are routed to ``multi_update``.
    """
    box = MultiSelect(title=cont, value=[], options=by_cont[cont])
    box.on_change("value", multi_update)
    return box
示例#3
0
    def layout_hook(self):
        """Create the MultiSelect for ``self.coord`` and select its first value.

        The options are the unique values of ``self.context.data[self.coord]``
        (each used as both value and label); the visible list height is capped
        at ``self.max_elements``. The initial selection is mirrored into
        ``self.coord_vals`` before the context's handlers are refreshed.

        Returns:
            The configured ``MultiSelect`` widget.
        """
        unique_values = np.unique(self.context.data[self.coord])
        options = [(value, value) for value in unique_values]
        first_value = options[0][0]

        widget = MultiSelect(title=self.coord,
                             value=[first_value],
                             options=options)
        widget.size = min(len(options), self.max_elements)
        widget.on_change("value", self.on_selected_coord_change)

        self.coord_vals = [first_value]
        self.context._update_handlers()

        return widget
示例#4
0
    def document(doc):
        """Build the slider-filter / describe() app for ``df`` on ``doc``.

        A ``RangeSlider`` is created for every column in ``cols`` (initially
        spanning that column's min..max) and a ``MultiSelect`` lists all
        columns of ``df``. Clicking "Update" keeps the rows whose values lie
        inside every slider range (bounds included) and writes ``describe()``
        of the selected columns into the ``PreText``.
        """
        sliders = []
        pre = PreText(text='something')
        select = MultiSelect(title='Select the columns',
                             options=list(df.columns))
        button = Button(label='Update')
        for col in cols:
            col_min, col_max = df[col].min(), df[col].max()
            step = (col_max - col_min) / 100
            slider = RangeSlider(start=col_min,
                                 end=col_max,
                                 step=step,
                                 value=(col_min, col_max),
                                 title=col)
            sliders.append(slider)

        def update():
            # One boolean mask per slider instead of a query string:
            # Series.between is inclusive at both ends, so rows sitting on a
            # slider bound (e.g. the column min/max selected initially) are
            # kept, and column names that are not valid Python identifiers
            # work too.
            mask = None
            for slider in sliders:
                low, high = (float(bound) for bound in slider.value)
                in_range = df[slider.title].between(low, high)
                mask = in_range if mask is None else mask & in_range
            filtered = df if mask is None else df[mask]

            summary_cols = select.value
            if not summary_cols:
                # describe() raises on a frame without columns.
                pre.text = 'Select at least one column.'
                return
            pre.text = str(filtered[summary_cols].describe())

        button.on_click(update)

        root = layout([column(sliders), [select, button]], pre)
        doc.add_root(root)
def update_dataset(attr, old, new):
    """Rebuild the metadata filter widgets after the dataset selection changes.

    Resets ``rank_slice`` to the full size of the selected dataset, replaces
    the contents of ``metadata_filters`` with one widget per metadata
    attribute, wires each widget to ``update``, rebuilds ``inputs.children``
    and finally calls ``update`` with the same arguments.

    Args:
        attr, old, new: Bokeh ``on_change`` callback arguments, forwarded
            unchanged to ``update``.

    Raises:
        ValueError: if the dataset metadata declares an unsupported type.
    """
    selected_dataset = dataset.value
    dataset_size = data_manager.get_size(selected_dataset)
    rank_slice.end = dataset_size
    rank_slice.update(value=(1, dataset_size))

    metadata_type = data_manager.get_metadata_type(selected_dataset)
    metadata_domain = data_manager.get_metadata_domain(selected_dataset)

    # Empty the shared list in place: there is no ``global`` declaration, so
    # rebinding the name here would only create a local.
    while metadata_filters:
        metadata_filters.pop()

    for attribute, m_type in metadata_type.items():
        m_domain = metadata_domain[attribute]
        if m_type == 'boolean':
            control = Select(title=attribute,
                             value="Any",
                             options=["Any", "True", "False"])
        elif m_type == 'numerical':
            control = RangeSlider(start=m_domain[0],
                                  end=m_domain[1],
                                  value=m_domain,
                                  step=1,
                                  title=attribute)
        elif m_type in ('categorical', 'set'):
            # Both kinds are filtered the same way: every category starts
            # selected.
            categories = sorted(m_domain)
            control = MultiSelect(title=attribute,
                                  value=categories,
                                  options=categories)
        else:
            raise ValueError(
                'Unsupported attribute type {} in metadata'.format(m_type))
        metadata_filters.append(control)

    for control in metadata_filters:
        if hasattr(control, 'value'):
            control.on_change('value', update)
        if hasattr(control, 'active'):
            control.on_change('active', update)

    inputs.children = build_controls()

    update(attr, old, new)
示例#6
0
    def __init__(self):
        """Load the rubric workbook and create the app's widgets.

        Reads the "Rubric v3" and "Cost_Model" sheets of
        ``data/Rubric.xlsx`` under ``file_path``, builds the criteria
        selector, converts the Excellent/Good/Poor grades to 1.0/0.5/0 and
        reshapes them into long (Criteria, Tool, Score) form, then creates
        the action buttons and their click handlers.
        """
        workbook = os.path.join(file_path, "data/Rubric.xlsx")
        self.rubric = pd.read_excel(workbook, "Rubric v3")
        self.cost_model = pd.read_excel(workbook, "Cost_Model")

        # Drop the descriptive columns. If any of them is missing, drop()
        # raises KeyError and the sheet is kept unchanged.
        try:
            self.rubric.drop(["Category", "Definition", "Grading Scale"],
                             inplace=True,
                             axis=1)
        except KeyError:
            pass

        self.criteria = self.rubric["Criteria"].drop_duplicates().tolist()

        self.swing_table = swing_table.create_swing_table()

        self.chosen_criteria = []

        self.criteria_selection = MultiSelect(title="Choose Criteria:",
                                              size=10)
        # Called before the attributes below exist; presumably it fills
        # criteria_selection -- confirm before reordering.
        self.choose_criteria()

        self.rubric_values = self.rubric.replace("Excellent", 1.0)
        self.rubric_values.replace("Good", 0.5, inplace=True)
        self.rubric_values.replace("Poor", 0, inplace=True)

        # var_name must be a scalar: recent pandas raises ValueError for a
        # list-like var_name unless the columns are a MultiIndex. The
        # resulting column is still named "Tool".
        self.rubric_values = self.rubric_values.melt(id_vars=["Criteria"],
                                                     var_name="Tool",
                                                     value_name="Score")

        self.weight_sliders = OrderedDict()
        self.ranking = OrderedDict()

        self.b = Button(label="Update Model", button_type="primary")
        self.b.on_click(self.submit_callback)

        self.criteria_b = Button(label="Submit Criteria",
                                 button_type="primary")
        self.criteria_b.on_click(self.choose_criteria_callback)

        self.clear_button = Button(label="Reset", button_type="warning")
        self.clear_button.on_click(self.clear_model)

        self.rank_submit = Button(label="Calculate Ranks",
                                  button_type="primary")
        self.rank_submit.on_click(self.submit_ranks)

        self.source = ColumnDataSource()

        # NOTE(review): this binds the DataTable *class*, not an instance;
        # presumably a real table is created later -- confirm.
        self.data_table = DataTable

        self.app_layout = layout()
示例#7
0
def timeplot(data):
    """Plot a time series with a MultiSelect that switches the Y column in the browser.

    Args:
        data (pandas.DataFrame): Must have an ``ltime`` column of timestamps.
            Every column after the first is offered as a Y option. Of those,
            all but the last two are converted to float in place when their
            first value is a string (thousands separators such as "1,234"
            are removed).
    """
    time = pd.DatetimeIndex(data['ltime'])
    # Column names from the second column of the DataFrame onwards.
    columns = list(data.columns[1:])
    # Convert comma-formatted numeric strings to float. The whole column is
    # reassigned because chained ``data[x][i] = ...`` assignment does not
    # write back under pandas Copy-on-Write, and ``iloc[0]`` is positional
    # whereas ``data[x][0]`` is a label lookup that fails when the index has
    # no 0 label.
    for x in columns[:-2]:
        if len(data) and isinstance(data[x].iloc[0], str):
            data[x] = [float(value.replace(',', '')) for value in data[x]]
    output_notebook()
    y = data[columns[1]]
    x = time
    p = Figure(x_axis_type='datetime', title="TimeSeries Plotting")

    source = ColumnDataSource(data=dict(x=x, y=y, d=data))
    # A second ColumnDataSource passes the column names to CustomJS.
    source2 = ColumnDataSource(data=dict(columns=columns))

    p.line('x', 'y', source=source)
    p.xaxis.axis_label = "Time"
    p.yaxis.axis_label = "Selected Y"
    callback = CustomJS(args=dict(source=source, columns=source2),
                        code="""
        var data = source.get('data');
        var columns = columns.get('data')['columns'];
        var f = cb_obj.get('value');
        var y = data['y'];
        var d = data['d'];
        // index of the chosen column in the widget's column list
        var index = -1;
        for (var i = 0; i < columns.length; i++) {
            if (f[0] == columns[i]) {
                index = i;
            }
        }
        if (index < 0) {
            // nothing selected or unknown column: leave the plot unchanged
            return;
        }
        // CustomJS receives the DataFrame as an array of rows; copy the
        // chosen column into y (+1 because `columns` omits the first
        // DataFrame column).
        for (var i = 0; i < d.length; i++) {
            y[i] = d[i][index + 1];
        }
        source.trigger('change');
        """)

    select = MultiSelect(title="Y_Option:",
                         value=[columns[0]],
                         options=columns,
                         callback=callback)
    layout = vform(select, p)
    show(layout)
示例#8
0
    def __init__(
        self,
        name="Specials",
        descr="Choose one",
        kind="specials",
        css_classes=None,
        entries=None,
        default="",
        title=None,
        none_allowed=False,
    ):
        """Single-choice selector built on a Bokeh ``MultiSelect``.

        Args:
            name: Stored as ``self.name``.
            descr: Stored as ``self.descr``.
            kind: Stored as ``self.kind``.
            css_classes: Stored as ``self.css_classes`` (a new empty list when
                omitted). It is not applied to the widget, which always gets
                the "deli-selector" class.
            entries: Mapping whose sorted keys become the options (a new
                empty dict when omitted).
            default: Initially selected value.
            title: Widget title. When None, the title "." is used and hidden
                with the "hide-title" class.
            none_allowed: If true, a "None" option is prepended.
        """
        # None sentinels: a literal [] / {} default would be one object
        # shared by every instance created without these arguments.
        if css_classes is None:
            css_classes = []
        if entries is None:
            entries = {}
        self.name = name
        self.descr = descr
        self.entries = entries
        self.kind = kind
        self.css_classes = css_classes
        options = sorted(entries.keys())
        if none_allowed:
            options = ["None"] + options
        if title is None:
            title = "."
            widget_css = ["deli-selector", "hide-title"]
        else:
            widget_css = ["deli-selector"]
        self.widget = MultiSelect(
            options=options,
            value=[default],
            size=8,
            name="deli-selector",
            title=title,
            css_classes=widget_css,
        )

        # HACK: force MultiSelect to only have 1 value selected by reverting
        # any change that selects more than one entry.
        def multi_select_hack(attr, old, new):
            if len(new) > 1:
                self.widget.value = old

        self.widget.on_change("value", multi_select_hack)
示例#9
0
    def modify_doc(self, doc):
        """Populate ``doc`` with the kernel-timing histogram and its selectors.

        Creates an initially empty quad (histogram) plot backed by
        ``self.source``, two MultiSelects (MPI/OpenMP configurations and
        invokes, both starting at "All") and a plot-type Select. Any
        selector change re-runs ``self.update()``, which is also called once
        before the layout is added to ``doc``.
        """
        self.source = ColumnDataSource(data=dict(top=[], left=[], right=[]))

        self.kplot = figure()
        self.kplot.y_range.start = 0
        self.kplot.quad(top="top", bottom=0, left="left", right="right",
                        source=self.source)

        desc = Div(text=description, sizing_mode="stretch_width")

        self.sel_procs = MultiSelect(title="Select MPI.OpenMP",
                                     options=self.procs, value=["All"])
        self.sel_plots = Select(title="Select Plot Type",
                                options=self.plots, value=self.plots[0])
        self.sel_invokes = MultiSelect(title="Select invokes",
                                       options=self.invokes, value=["All"])

        def _on_selector_change(attr, old, new):
            # Every selector triggers the same full refresh.
            self.update()

        for selector in (self.sel_procs, self.sel_invokes, self.sel_plots):
            selector.on_change("value", _on_selector_change)

        plot_column = column(self.sel_plots, self.kplot, self.sel_invokes)
        root = column(desc, row(self.sel_procs, plot_column))

        self.update()

        doc.add_root(root)
        doc.title = "Kernel Timing"
示例#10
0
    def document(doc):
        """Build the "query + describe()" app for ``df`` on ``doc``.

        The user enters a ``DataFrame.query`` expression and picks columns.
        Clicking "Update" shows ``describe()`` of those columns for the
        matching rows (all rows when the condition is blank), or an error
        message when the input cannot be evaluated.
        """
        condition = TextInput(title='Enter your condition')
        col = MultiSelect(title='Select the columns', options=list(df.columns))
        button = Button(label='Update')
        pre = PreText()
        pre.text = condition.value

        def update():
            cond = condition.value.strip()
            cols = col.value
            if not cols:
                # describe() raises on a frame without columns.
                pre.text = 'Select at least one column.'
                return
            # SECURITY: ``cond`` is free text from the browser and is
            # evaluated by DataFrame.query / pandas.eval, which the pandas
            # documentation warns can run arbitrary code. Only expose this
            # app to trusted users.
            try:
                # query('') raises, so a blank condition means "all rows".
                rows = df.query(cond) if cond else df
                text = rows[cols].describe()
            except Exception as err:  # bad user expression: report it
                pre.text = 'Error: {}'.format(err)
                return
            pre.text = str(text)

        button.on_click(update)

        l = layout([[[condition, button], col], pre])
        doc.add_root(l)
示例#11
0
def count_clients():
    """Build a bar chart of client counts per user.

    Joins the ``clients`` table to ``users`` on ``clients.user_id ==
    users.id``, counts rows per username and plots the ``first_name`` count
    as bars.

    Returns:
        tuple: ``(df, layout)``. ``df`` is the per-username ``count()`` frame
        and ``layout`` is a ``column`` with a username MultiSelect above the
        bar plot.
    """
    import logging

    df_clients = pd.read_sql("clients", db.engine)
    df_users = pd.read_sql("users", db.engine)
    df = pd.merge(
        df_clients,
        df_users[["id", "username"]],
        how="left",
        left_on="user_id",
        right_on="id",
    )
    df = df.groupby("username").count()
    logging.getLogger(__name__).debug("client counts per user:\n%s", df)

    options = list(df.index)

    plot = figure(x_range=options)
    plot.vbar(x=options, top=df.first_name, width=0.75)

    # Pre-select the first user if there is one; options[0] would raise
    # IndexError when no clients exist.
    select = MultiSelect(options=options, value=options[:1])
    layout = column(select, plot)

    return df, layout
示例#12
0
# initialize widgets
save_button = Button(label="Save flagged data", button_type="success")
parameter = Select(
    title="Parameter",
    options=["CTDSAL", "CTDTMP"],
    value="CTDSAL",
)
ref_param = Select(title="Reference", options=["SALNTY"], value="SALNTY")
# Dropdown options can be changed later, e.g. ref_param.options = [...]
station = Select(title="Station")

# Flag code definitions:
# https://cchdo.github.io/hdo-assets/documentation/manuals/pdf/90_1/chap4.pdf
# flag_list.value is the list of selected codes, e.g. ['2'] or ['1','2'].
flag_list = MultiSelect(
    title="Plot data flagged as:",
    value=["1", "2", "3"],
    options=[
        ("1", "1 [Uncalibrated]"),
        ("2", "2 [Acceptable]"),
        ("3", "3 [Questionable]"),
        ("4", "4 [Bad]"),
    ],
)
# Single flag code chosen by the user (defaults to "3").
flag_input = Select(
    title="Flag:",
    value="3",
    options=[
        ("1", "1 [Uncalibrated]"),
        ("2", "2 [Acceptable]"),
        ("3", "3 [Questionable]"),
        ("4", "4 [Bad]"),
    ],
)
def plot(tables, output_filename):
    '''
    Build an interactive hexagon plot with Bokeh widgets, save it and show it.

    This function receives:
    - tables: dictionary with tables used to create arrays of repeated x, y coordinates (depending on the counts) for the hexagon plot.
    - output_filename: name (without extension) of the .html output in the plots folder; also used as the title of the term MultiSelect.

    The coordinate arrays are used to create a pandas dataframe with Bokeh functions. This dataframe contains the q, r coordinates and counts used to plot the
    hexagons. To this dataframe, extra information is added (e.g. most common chemicals), which is displayed in the hover tooltip.

    Gaussian blur is added to copies of this dataframe and given as input to the Bokeh slider widget.
    Other widgets are added as well, for saturation, normalisation etc. Bokeh allows customizing these widgets with JavaScript code.

    The hexagon plot is saved as a .html file and also shown in the browser.
    '''

    file_name = f'plots/{output_filename}.html'
    output_file(file_name)

    # Blur and saturation values
    BLUR_MAX = 3
    BLUR_STEP_SIZE = 1
    SATURATION_MAX = 5
    SATURATION_STEP_SIZE = 0.25

    # First, create array for plot properties (ratio, size of hexagons etc.)
    default_term = list(tables.keys())[0]
    x, y, ids = create_array(tables[default_term]['table'],
                             normalisation=False)

    # Hexagon plot properties
    length = len(x)
    orientation = 'flattop'
    # NOTE(review): divides by zero when all x (or all y) values are equal.
    ratio = (max(y) - min(y)) / (max(x) - min(x))
    size = 10 / ratio
    h = sqrt(3) * size * ratio
    title = (f'Hexbin plot for {length} annotated chemicals '
             f'with query {default_term}')

    # make figure
    p = figure(title=title,
               x_range=[min(x) - 0.5, max(x) + 0.5],
               y_range=[0 - (h / 2), max(y) + 100],
               tools="wheel_zoom,reset,save",
               background_fill_color='#440154')

    p.grid.visible = False
    p.xaxis.axis_label = "log(P)"
    p.yaxis.axis_label = "mass in Da"
    p.xaxis.axis_label_text_font_style = 'normal'
    p.yaxis.axis_label_text_font_style = 'normal'

    # source for plot
    term_to_source, term_to_metadata, options = make_plot_sources(
        tables, size, ratio, orientation, BLUR_MAX, BLUR_STEP_SIZE)

    # Start source for the plot (first displayed in the hexagon figure).
    # The coordinates are recomputed after make_plot_sources() in case that
    # call modifies the tables -- TODO confirm, and reuse x, y, ids if not.
    x, y, ids = create_array(tables[default_term]['table'],
                             normalisation=False)
    df = hexbin(x, y, ids, size, aspect_scale=ratio, orientation=orientation)
    df = add_counts(df, tables[default_term]['table'])
    source = ColumnDataSource(df)
    metadata = return_html(term_to_metadata[default_term])

    # color mapper
    mapper = linear_cmap('scaling', 'Viridis256', 0,
                         max(source.data['scaling']))

    # plot (the returned renderer is not used, so it is not bound to a name)
    p.hex_tile(q="q",
               r="r",
               size=size,
               line_color=None,
               source=source,
               aspect_scale=ratio,
               orientation=orientation,
               fill_color=mapper)

    # HOVER
    TOOLTIPS = return_tooltip()
    code_callback_hover = return_code('hover')
    callback_hover = CustomJS(code=code_callback_hover)
    hover = HoverTool(tooltips=TOOLTIPS,
                      callback=callback_hover,
                      show_arrow=False)
    p.add_tools(hover)

    # WIDGETS
    slider1 = Slider(start=1,
                     end=SATURATION_MAX,
                     value=1,
                     step=SATURATION_STEP_SIZE,
                     title="Saturation",
                     width=100)
    slider2 = Slider(start=0,
                     end=BLUR_MAX,
                     value=0,
                     step=BLUR_STEP_SIZE,
                     title="Blur",
                     width=100)
    checkbox = CheckboxGroup(labels=["TFIDF"], active=[])
    radio_button_group = RadioGroup(labels=["Viridis256", "Greys256"],
                                    active=0)
    button = Button(label="Metadata", button_type="default", width=100)
    multi_select = MultiSelect(title=output_filename,
                               value=[default_term],
                               options=options,
                               width=100,
                               height=300)

    # WIDGETS CODE FOR CALLBACK
    code_callback_slider1 = return_code('slider1')
    code_callback_slider2 = return_code('slider2')
    code_callback_checkbox = return_code('checkbox')
    code_callback_rbg = return_code('rbg')
    code_callback_button = return_code('button')
    code_callback_ms = return_code('multi_select')

    # WIDGETS CALLBACK
    callback_slider1 = CustomJS(args={
        'source': source,
        'mapper': mapper
    },
                                code=code_callback_slider1)
    callback_slider2 = CustomJS(args={
        'source': source,
        'mapper': mapper,
        'slider1': slider1,
        'multi_select': multi_select,
        'checkbox': checkbox,
        'term_to_source': term_to_source,
        'step_size': BLUR_STEP_SIZE
    },
                                code=code_callback_slider2)
    callback_checkbox = CustomJS(args={
        'source': source,
        'term_to_source': term_to_source,
        'multi_select': multi_select,
        'step_size': BLUR_STEP_SIZE,
        'slider1': slider1,
        'slider2': slider2,
        'mapper': mapper
    },
                                 code=code_callback_checkbox)
    callback_radio_button_group = CustomJS(args={
        'p': p,
        'mapper': mapper,
        'Viridis256': Viridis256,
        'Greys256': Greys256
    },
                                           code=code_callback_rbg)
    callback_button = CustomJS(args={
        'term_to_metadata': term_to_metadata,
        'multi_select': multi_select
    },
                               code=code_callback_button)
    callback_ms = CustomJS(args={
        'source': source,
        'term_to_source': term_to_source,
        'checkbox': checkbox,
        'metadata': metadata,
        'step_size': BLUR_STEP_SIZE,
        'slider2': slider2,
        'slider1': slider1,
        'p': p,
        'mapper': mapper
    },
                           code=code_callback_ms)

    # WIDGETS INTERACTION
    slider1.js_on_change('value', callback_slider1)
    slider2.js_on_change('value', callback_slider2)
    checkbox.js_on_change('active', callback_checkbox)
    radio_button_group.js_on_change('active', callback_radio_button_group)
    button.js_on_event(events.ButtonClick, callback_button)
    multi_select.js_on_change("value", callback_ms)

    # LAYOUT
    layout = row(
        multi_select, p,
        column(slider1, slider2, checkbox, radio_button_group, button))

    show(layout)
示例#14
0
                      button_type="primary")
shuffle_button = Button(label="Shake!", button_type="primary")
start_button = Button(label="Start timer", button_type="success")
stop_button = Button(label="Stop timer", button_type="danger")
# Large centred timer display, starting at 0:00.
timer = Div(
    text="""Timer: <br> 0:00""",
    style={"font-size": "400%", "color": "black", "text-align": "center"},
)
show_words_button = Button(label="Show all words?", button_type="danger")
show_words_options = RadioButtonGroup(
    labels=["Alphabetical", "By Length"],
    active=0,
)
word_select = MultiSelect(value=[], options=[], height=500, size=10)
word_count = Div(text="")

# Histogram of words found per game; the source starts empty.
hist_src = ColumnDataSource(dict(hist=[], left=[], right=[]))
# pdf_src = ColumnDataSource(dict(x=[], pdf=[]))
p = figure(
    title="Word Count",
    plot_height=200,
    plot_width=350,
    tools="",
    background_fill_color="#fafafa",
    x_axis_label="# of Words",
    y_axis_label="# of Games",
)
p.quad(
示例#15
0
        columns={"New Flag_x": "New Flag", "Comments_x": "Comments"}
    ).drop(columns=["New Flag_y", "Comments_y"])

# initialize widgets
save_button = Button(label="Save flagged data", button_type="success")
parameter = Select(
    title="Parameter",
    options=["CTDSAL", "CTDTMP"],
    value="CTDSAL",
)
ref_param = Select(title="Reference", options=["SALNTY"], value="SALNTY")
# Dropdown options can be changed later, e.g. ref_param.options = [...]
station = Select(
    title="Station",
    options=ssscc_list,
    value=ssscc_list[0],
)

# Flag code definitions:
# https://cchdo.github.io/hdo-assets/documentation/manuals/pdf/90_1/chap4.pdf
# flag_list.value is the list of selected codes, e.g. ['2'] or ['1','2'].
flag_list = MultiSelect(
    title="Plot data flagged as:",
    value=["1", "2", "3"],
    options=[
        ("1", "1 [Uncalibrated]"),
        ("2", "2 [Acceptable]"),
        ("3", "3 [Questionable]"),
        ("4", "4 [Bad]"),
    ],
)
# Single flag code chosen by the user (defaults to "3").
flag_input = Select(
    title="Flag:",
    value="3",
    options=[
        ("1", "1 [Uncalibrated]"),
        ("2", "2 [Acceptable]"),
        ("3", "3 [Questionable]"),
        ("4", "4 [Bad]"),
    ],
)
def create():
    doc = curdoc()
    det_data = {}
    cami_meta = {}

    def proposal_textinput_callback(_attr, _old, new):
        nonlocal cami_meta
        proposal = new.strip()
        for zebra_proposals_path in pyzebra.ZEBRA_PROPOSALS_PATHS:
            proposal_path = os.path.join(zebra_proposals_path, proposal)
            if os.path.isdir(proposal_path):
                # found it
                break
        else:
            raise ValueError(f"Can not find data for proposal '{proposal}'.")

        file_list = []
        for file in os.listdir(proposal_path):
            if file.endswith(".hdf"):
                file_list.append((os.path.join(proposal_path, file), file))
        file_select.options = file_list

        cami_meta = {}

    proposal_textinput = TextInput(title="Proposal number:", width=210)
    proposal_textinput.on_change("value", proposal_textinput_callback)

    def upload_button_callback(_attr, _old, new):
        nonlocal cami_meta
        with io.StringIO(base64.b64decode(new).decode()) as file:
            cami_meta = pyzebra.parse_h5meta(file)
            file_list = cami_meta["filelist"]
            file_select.options = [(entry, os.path.basename(entry))
                                   for entry in file_list]

    upload_div = Div(text="or upload .cami file:", margin=(5, 5, 0, 5))
    upload_button = FileInput(accept=".cami", width=200)
    upload_button.on_change("value", upload_button_callback)

    def update_image(index=None):
        if index is None:
            index = index_spinner.value

        current_image = det_data["data"][index]
        proj_v_line_source.data.update(x=np.arange(0, IMAGE_W) + 0.5,
                                       y=np.mean(current_image, axis=0))
        proj_h_line_source.data.update(x=np.mean(current_image, axis=1),
                                       y=np.arange(0, IMAGE_H) + 0.5)

        image_source.data.update(
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
        )
        image_source.data.update(image=[current_image])

        if main_auto_checkbox.active:
            im_min = np.min(current_image)
            im_max = np.max(current_image)

            display_min_spinner.value = im_min
            display_max_spinner.value = im_max

            image_glyph.color_mapper.low = im_min
            image_glyph.color_mapper.high = im_max

        if "mf" in det_data:
            metadata_table_source.data.update(mf=[det_data["mf"][index]])
        else:
            metadata_table_source.data.update(mf=[None])

        if "temp" in det_data:
            metadata_table_source.data.update(temp=[det_data["temp"][index]])
        else:
            metadata_table_source.data.update(temp=[None])

        gamma, nu = calculate_pol(det_data, index)
        omega = np.ones((IMAGE_H, IMAGE_W)) * det_data["omega"][index]
        image_source.data.update(gamma=[gamma], nu=[nu], omega=[omega])

    def update_overview_plot():
        h5_data = det_data["data"]
        n_im, n_y, n_x = h5_data.shape
        overview_x = np.mean(h5_data, axis=1)
        overview_y = np.mean(h5_data, axis=2)

        overview_plot_x_image_source.data.update(image=[overview_x],
                                                 dw=[n_x],
                                                 dh=[n_im])
        overview_plot_y_image_source.data.update(image=[overview_y],
                                                 dw=[n_y],
                                                 dh=[n_im])

        if proj_auto_checkbox.active:
            im_min = min(np.min(overview_x), np.min(overview_y))
            im_max = max(np.max(overview_x), np.max(overview_y))

            proj_display_min_spinner.value = im_min
            proj_display_max_spinner.value = im_max

            overview_plot_x_image_glyph.color_mapper.low = im_min
            overview_plot_y_image_glyph.color_mapper.low = im_min
            overview_plot_x_image_glyph.color_mapper.high = im_max
            overview_plot_y_image_glyph.color_mapper.high = im_max

        frame_range.start = 0
        frame_range.end = n_im
        frame_range.reset_start = 0
        frame_range.reset_end = n_im
        frame_range.bounds = (0, n_im)

        scan_motor = det_data["scan_motor"]
        overview_plot_y.axis[1].axis_label = f"Scanning motor, {scan_motor}"

        var = det_data[scan_motor]
        var_start = var[0]
        var_end = var[-1] + (var[-1] - var[0]) / (n_im - 1)

        scanning_motor_range.start = var_start
        scanning_motor_range.end = var_end
        scanning_motor_range.reset_start = var_start
        scanning_motor_range.reset_end = var_end
        # handle both, ascending and descending sequences
        scanning_motor_range.bounds = (min(var_start,
                                           var_end), max(var_start, var_end))

    def file_select_callback(_attr, old, new):
        nonlocal det_data
        if not new:
            # skip empty selections
            return

        # Avoid selection of multiple indicies (via Shift+Click or Ctrl+Click)
        if len(new) > 1:
            # drop selection to the previous one
            file_select.value = old
            return

        if len(old) > 1:
            # skip unnecessary update caused by selection drop
            return

        det_data = pyzebra.read_detector_data(new[0])

        if cami_meta and "crystal" in cami_meta:
            det_data["ub"] = cami_meta["crystal"]["UB"]

        index_spinner.value = 0
        index_spinner.high = det_data["data"].shape[0] - 1
        index_slider.end = det_data["data"].shape[0] - 1

        zebra_mode = det_data["zebra_mode"]
        if zebra_mode == "nb":
            metadata_table_source.data.update(geom=["normal beam"])
        else:  # zebra_mode == "bi"
            metadata_table_source.data.update(geom=["bisecting"])

        update_image(0)
        update_overview_plot()

    file_select = MultiSelect(title="Available .hdf files:",
                              width=210,
                              height=250)
    file_select.on_change("value", file_select_callback)

    def index_callback(_attr, _old, new):
        update_image(new)

    index_slider = Slider(value=0, start=0, end=1, show_value=False, width=400)

    index_spinner = Spinner(title="Image index:", value=0, low=0, width=100)
    index_spinner.on_change("value", index_callback)

    index_slider.js_link("value_throttled", index_spinner, "value")
    index_spinner.js_link("value", index_slider, "value")

    plot = Plot(
        x_range=Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)),
        y_range=Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)),
        plot_height=IMAGE_PLOT_H,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    plot.toolbar.logo = None

    # ---- axes
    plot.add_layout(LinearAxis(), place="above")
    plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                    place="right")

    # ---- grid lines
    plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    image_source = ColumnDataSource(
        dict(
            image=[np.zeros((IMAGE_H, IMAGE_W), dtype="float32")],
            h=[np.zeros((1, 1))],
            k=[np.zeros((1, 1))],
            l=[np.zeros((1, 1))],
            gamma=[np.zeros((1, 1))],
            nu=[np.zeros((1, 1))],
            omega=[np.zeros((1, 1))],
            x=[0],
            y=[0],
            dw=[IMAGE_W],
            dh=[IMAGE_H],
        ))

    h_glyph = Image(image="h", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    k_glyph = Image(image="k", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    l_glyph = Image(image="l", x="x", y="y", dw="dw", dh="dh", global_alpha=0)
    gamma_glyph = Image(image="gamma",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)
    nu_glyph = Image(image="nu",
                     x="x",
                     y="y",
                     dw="dw",
                     dh="dh",
                     global_alpha=0)
    omega_glyph = Image(image="omega",
                        x="x",
                        y="y",
                        dw="dw",
                        dh="dh",
                        global_alpha=0)

    plot.add_glyph(image_source, h_glyph)
    plot.add_glyph(image_source, k_glyph)
    plot.add_glyph(image_source, l_glyph)
    plot.add_glyph(image_source, gamma_glyph)
    plot.add_glyph(image_source, nu_glyph)
    plot.add_glyph(image_source, omega_glyph)

    image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh")
    plot.add_glyph(image_source, image_glyph, name="image_glyph")

    # ---- projections
    proj_v = Plot(
        x_range=plot.x_range,
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location=None,
    )

    proj_v.add_layout(LinearAxis(major_label_orientation="vertical"),
                      place="right")
    proj_v.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="below")

    proj_v.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_v.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_v_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_v.add_glyph(proj_v_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    proj_h = Plot(
        x_range=DataRange1d(),
        y_range=plot.y_range,
        plot_height=IMAGE_PLOT_H,
        plot_width=150,
        toolbar_location=None,
    )

    proj_h.add_layout(LinearAxis(), place="above")
    proj_h.add_layout(LinearAxis(major_label_text_font_size="0pt"),
                      place="left")

    proj_h.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    proj_h.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    proj_h_line_source = ColumnDataSource(dict(x=[], y=[]))
    proj_h.add_glyph(proj_h_line_source,
                     Line(x="x", y="y", line_color="steelblue"))

    # add tools
    hovertool = HoverTool(tooltips=[
        ("intensity", "@image"),
        ("gamma", "@gamma"),
        ("nu", "@nu"),
        ("omega", "@omega"),
        ("h", "@h"),
        ("k", "@k"),
        ("l", "@l"),
    ])

    box_edit_source = ColumnDataSource(dict(x=[], y=[], width=[], height=[]))
    box_edit_glyph = Rect(x="x",
                          y="y",
                          width="width",
                          height="height",
                          fill_alpha=0,
                          line_color="red")
    box_edit_renderer = plot.add_glyph(box_edit_source, box_edit_glyph)
    boxedittool = BoxEditTool(renderers=[box_edit_renderer], num_objects=1)

    def box_edit_callback(_attr, _old, new):
        """Update the ROI-sum curve whenever the box-edit rectangle changes.

        Sums ``det_data["data"]`` inside the drawn box for every frame and
        pushes the (frame index, sum) curve into ``roi_avg_plot_line_source``.
        An empty box clears the curve.
        """
        if new["x"]:
            h5_data = det_data["data"]
            x_val = np.arange(h5_data.shape[0])
            # The box is drawn with a Rect glyph, whose x/y are the rectangle
            # CENTRE, so the edges are centre -/+ half the size.
            half_w = new["width"][0] / 2
            half_h = new["height"][0] / 2
            # Clamp the lower edges at 0: a negative start index would count
            # from the end of the axis and give a wrong (usually empty) slice.
            # Upper edges beyond the image are already safe for slicing.
            left = max(int(np.floor(new["x"][0] - half_w)), 0)
            right = int(np.ceil(new["x"][0] + half_w))
            bottom = max(int(np.floor(new["y"][0] - half_h)), 0)
            top = int(np.ceil(new["y"][0] + half_h))
            y_val = np.sum(h5_data[:, bottom:top, left:right], axis=(1, 2))
        else:
            x_val = []
            y_val = []

        roi_avg_plot_line_source.data.update(x=x_val, y=y_val)

    box_edit_source.on_change("data", box_edit_callback)

    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    plot.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
        hovertool,
        boxedittool,
    )
    plot.toolbar.active_scroll = wheelzoomtool

    # shared frame ranges
    frame_range = Range1d(0, 1, bounds=(0, 1))
    scanning_motor_range = Range1d(0, 1, bounds=(0, 1))

    det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W))
    overview_plot_x = Plot(
        title=Title(text="Projections on X-axis"),
        x_range=det_x_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_W - 3,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_x.toolbar.logo = None
    overview_plot_x.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_x.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"),
                               place="below")
    overview_plot_x.add_layout(LinearAxis(axis_label="Frame",
                                          major_label_orientation="vertical"),
                               place="left")

    # ---- grid lines
    overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_x_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_W],
             dh=[1]))

    overview_plot_x_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_x.add_glyph(overview_plot_x_image_source,
                              overview_plot_x_image_glyph,
                              name="image_glyph")

    det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H))
    overview_plot_y = Plot(
        title=Title(text="Projections on Y-axis"),
        x_range=det_y_range,
        y_range=frame_range,
        extra_y_ranges={"scanning_motor": scanning_motor_range},
        plot_height=400,
        plot_width=IMAGE_PLOT_H + 22,
    )

    # ---- tools
    wheelzoomtool = WheelZoomTool(maintain_focus=False)
    overview_plot_y.toolbar.logo = None
    overview_plot_y.add_tools(
        PanTool(),
        BoxZoomTool(),
        wheelzoomtool,
        ResetTool(),
    )
    overview_plot_y.toolbar.active_scroll = wheelzoomtool

    # ---- axes
    overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"),
                               place="below")
    overview_plot_y.add_layout(
        LinearAxis(
            y_range_name="scanning_motor",
            axis_label="Scanning motor",
            major_label_orientation="vertical",
        ),
        place="right",
    )

    # ---- grid lines
    overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    # ---- rgba image glyph
    overview_plot_y_image_source = ColumnDataSource(
        dict(image=[np.zeros((1, 1), dtype="float32")],
             x=[0],
             y=[0],
             dw=[IMAGE_H],
             dh=[1]))

    overview_plot_y_image_glyph = Image(image="image",
                                        x="x",
                                        y="y",
                                        dw="dw",
                                        dh="dh")
    overview_plot_y.add_glyph(overview_plot_y_image_source,
                              overview_plot_y_image_glyph,
                              name="image_glyph")

    roi_avg_plot = Plot(
        x_range=DataRange1d(),
        y_range=DataRange1d(),
        plot_height=150,
        plot_width=IMAGE_PLOT_W,
        toolbar_location="left",
    )

    # ---- tools
    roi_avg_plot.toolbar.logo = None

    # ---- axes
    roi_avg_plot.add_layout(LinearAxis(), place="below")
    roi_avg_plot.add_layout(LinearAxis(major_label_orientation="vertical"),
                            place="left")

    # ---- grid lines
    roi_avg_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
    roi_avg_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))

    roi_avg_plot_line_source = ColumnDataSource(dict(x=[], y=[]))
    roi_avg_plot.add_glyph(roi_avg_plot_line_source,
                           Line(x="x", y="y", line_color="steelblue"))

    cmap_dict = {
        "gray": Greys256,
        "gray_reversed": Greys256[::-1],
        "plasma": Plasma256,
        "cividis": Cividis256,
    }

    def colormap_callback(_attr, _old, new):
        """Switch the palette of the main image and both overview images.

        A fresh LinearColorMapper with palette ``cmap_dict[new]`` is installed
        on each glyph. The previous mapper's ``low``/``high`` are carried over,
        so limits set through the display-range spinners survive a colormap
        change instead of silently reverting to auto-scaling.
        """
        palette = cmap_dict[new]
        for glyph in (image_glyph, overview_plot_x_image_glyph,
                      overview_plot_y_image_glyph):
            old_mapper = glyph.color_mapper
            glyph.color_mapper = LinearColorMapper(
                palette=palette,
                low=getattr(old_mapper, "low", None),
                high=getattr(old_mapper, "high", None),
            )

    colormap = Select(title="Colormap:",
                      options=list(cmap_dict.keys()),
                      width=210)
    colormap.on_change("value", colormap_callback)
    colormap.value = "plasma"

    STEP = 1

    def main_auto_checkbox_callback(state):
        """Lock or unlock the main display-range spinners, then redraw.

        ``state`` is whatever ``CheckboxGroup.on_click`` passes (presumably the
        list of active indices). Any truthy value means auto range is on, so
        the manual min/max spinners are disabled.
        """
        auto_on = bool(state)
        for spinner in (display_min_spinner, display_max_spinner):
            spinner.disabled = auto_on

        update_image()

    main_auto_checkbox = CheckboxGroup(labels=["Main Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    main_auto_checkbox.on_click(main_auto_checkbox_callback)

    def display_max_spinner_callback(_attr, _old_value, new_value):
        """Apply a new upper display limit to the main image color mapper."""
        upper = new_value
        # keep the min spinner at least one STEP below the new maximum
        display_min_spinner.high = upper - STEP
        image_glyph.color_mapper.high = upper

    display_max_spinner = Spinner(
        low=0 + STEP,
        value=1,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_max_spinner.on_change("value", display_max_spinner_callback)

    def display_min_spinner_callback(_attr, _old_value, new_value):
        """Apply a new lower display limit to the main image color mapper."""
        lower = new_value
        # keep the max spinner at least one STEP above the new minimum
        display_max_spinner.low = lower + STEP
        image_glyph.color_mapper.low = lower

    display_min_spinner = Spinner(
        low=0,
        high=1 - STEP,
        value=0,
        step=STEP,
        disabled=bool(main_auto_checkbox.active),
        width=100,
        height=31,
    )
    display_min_spinner.on_change("value", display_min_spinner_callback)

    PROJ_STEP = 0.1

    def proj_auto_checkbox_callback(state):
        """Lock or unlock the projection range spinners, then redraw overviews.

        Any truthy ``state`` (the checkbox group's selection, presumably its
        active indices) means auto range, so the manual spinners are disabled.
        """
        auto_on = bool(state)
        proj_display_min_spinner.disabled = auto_on
        proj_display_max_spinner.disabled = auto_on

        update_overview_plot()

    proj_auto_checkbox = CheckboxGroup(labels=["Projections Auto Range"],
                                       active=[0],
                                       width=145,
                                       margin=[10, 5, 0, 5])
    proj_auto_checkbox.on_click(proj_auto_checkbox_callback)

    def proj_display_max_spinner_callback(_attr, _old_value, new_value):
        """Apply a new upper limit to both projection-overview color mappers."""
        # keep the min spinner at least one PROJ_STEP below the new maximum
        proj_display_min_spinner.high = new_value - PROJ_STEP
        for glyph in (overview_plot_x_image_glyph, overview_plot_y_image_glyph):
            glyph.color_mapper.high = new_value

    proj_display_max_spinner = Spinner(
        low=0 + PROJ_STEP,
        value=1,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_max_spinner.on_change("value",
                                       proj_display_max_spinner_callback)

    def proj_display_min_spinner_callback(_attr, _old_value, new_value):
        """Apply a new lower limit to both projection-overview color mappers."""
        # keep the max spinner at least one PROJ_STEP above the new minimum
        proj_display_max_spinner.low = new_value + PROJ_STEP
        for glyph in (overview_plot_x_image_glyph, overview_plot_y_image_glyph):
            glyph.color_mapper.low = new_value

    proj_display_min_spinner = Spinner(
        low=0,
        high=1 - PROJ_STEP,
        value=0,
        step=PROJ_STEP,
        disabled=bool(proj_auto_checkbox.active),
        width=100,
        height=31,
    )
    proj_display_min_spinner.on_change("value",
                                       proj_display_min_spinner_callback)

    def hkl_button_callback():
        """Compute h, k, l for the frame selected in ``index_spinner``.

        The result of ``calculate_hkl`` is stored as single-element columns
        ``h``, ``k`` and ``l`` of ``image_source`` (shown by the hover tool).
        """
        h_val, k_val, l_val = calculate_hkl(det_data, index_spinner.value)
        image_source.data.update(h=[h_val], k=[k_val], l=[l_val])

    hkl_button = Button(label="Calculate hkl (slow)", width=210)
    hkl_button.on_click(hkl_button_callback)

    def events_list_callback(_attr, _old, new):
        """Mirror edits of this events list into ``doc.events_list_spind``."""
        spind_events = doc.events_list_spind
        spind_events.value = new

    events_list = TextAreaInput(rows=7, width=830)
    events_list.on_change("value", events_list_callback)
    doc.events_list_hdf_viewer = events_list

    def add_event_button_callback():
        """Fit the peak in the current view and append it as a spind event.

        The region of interest is what the overview plots currently show: the
        detector x/y ranges and the frame range. ``gauss`` is fitted to the
        ROI's projections on the frame, x and y axes to locate the peak. The
        fitted frame is converted to a scanning-motor angle by linear
        interpolation in ``det_data[scan_motor]``. The detector position is
        converted with ``pyzebra.det2pol`` / ``z1frmd`` / ``dandth``. Finally a
        line ``"x y intensity snr dv1 dv2 dv3 d_spacing"`` is appended to
        ``events_list``.
        """
        p0 = [1.0, 0.0, 1.0]
        maxfev = 100000

        def fit_gauss(profile):
            """Fit ``gauss`` to a 1D profile indexed 0..len-1; return coeffs.

            As used below, coefficient 1 is the peak centre and coefficient 2
            its width, which leaves coefficient 0 as the amplitude.
            """
            coeff, _ = curve_fit(gauss,
                                 range(len(profile)),
                                 profile,
                                 p0=p0,
                                 maxfev=maxfev)
            return coeff

        wave = det_data["wave"]
        ddist = det_data["ddist"]

        gamma = det_data["gamma"][0]
        omega = det_data["omega"][0]
        nu = det_data["nu"][0]
        chi = det_data["chi"][0]
        phi = det_data["phi"][0]

        scan_motor = det_data["scan_motor"]
        var_angle = det_data[scan_motor]

        x0 = int(np.floor(det_x_range.start))
        xN = int(np.ceil(det_x_range.end))
        y0 = int(np.floor(det_y_range.start))
        yN = int(np.ceil(det_y_range.end))
        fr0 = int(np.floor(frame_range.start))
        frN = int(np.ceil(frame_range.end))
        data_roi = det_data["data"][fr0:frN, y0:yN, x0:xN]

        # ---- peak along the scan (frame axis)
        cnts = np.sum(data_roi, axis=(1, 2))
        fr_amp, fr_center, fr_width = fit_gauss(cnts)

        m = cnts.mean()
        sd = cnts.std()
        # explicit branch: np.where would still evaluate m / sd when sd == 0
        snr_cnts = 0.0 if sd == 0 else m / sd

        # Linearly interpolate the scanning-motor angle at the fitted frame.
        # The segment index is clamped to the scan, so a centre at or beyond
        # the last frame does not raise IndexError, and a negative centre does
        # not silently wrap to the end of the array. It also keeps var_step
        # (the per-frame angle step) non-zero when the centre lands exactly
        # on a frame.
        frC = fr0 + fr_center
        seg = min(max(math.floor(frC), 0), len(var_angle) - 2)
        var_step = var_angle[seg + 1] - var_angle[seg]
        var_p = var_angle[seg] + var_step * (frC - seg)

        if scan_motor == "gamma":
            gamma = var_p
        elif scan_motor == "omega":
            omega = var_p
        elif scan_motor == "nu":
            nu = var_p
        elif scan_motor == "chi":
            chi = var_p
        elif scan_motor == "phi":
            phi = var_p

        # Integrated Gaussian: amplitude * |sigma| * sqrt(2*pi), with sigma
        # converted from frames to motor units via var_step. The amplitude is
        # coefficient 0; coefficient 1 is the centre position, not a height.
        intensity = fr_amp * abs(fr_width * var_step) * math.sqrt(2 * math.pi)

        # ---- peak position on the detector
        _, x_center, _ = fit_gauss(np.sum(data_roi, axis=(0, 1)))
        x_pos = x0 + x_center

        _, y_center, _ = fit_gauss(np.sum(data_roi, axis=(0, 2)))
        y_pos = y0 + y_center

        ga, nu = pyzebra.det2pol(ddist, gamma, nu, x_pos, y_pos)
        diff_vector = pyzebra.z1frmd(wave, ga, omega, chi, phi, nu)
        d_spacing = float(pyzebra.dandth(wave, diff_vector)[0])
        diff_vector = diff_vector.flatten() * 1e10
        dv1, dv2, dv3 = diff_vector

        # make sure the new event starts on its own line
        if events_list.value and not events_list.value.endswith("\n"):
            events_list.value = events_list.value + "\n"

        events_list.value = (
            events_list.value +
            f"{x_pos} {y_pos} {intensity} {snr_cnts} {dv1} {dv2} {dv3} {d_spacing}"
        )

    add_event_button = Button(label="Add spind event")
    add_event_button.on_click(add_event_button_callback)

    metadata_table_source = ColumnDataSource(
        dict(geom=[""], temp=[None], mf=[None]))
    num_formatter = NumberFormatter(format="0.00", nan_format="")
    metadata_table = DataTable(
        source=metadata_table_source,
        columns=[
            TableColumn(field="geom", title="Geometry", width=100),
            TableColumn(field="temp",
                        title="Temperature",
                        formatter=num_formatter,
                        width=100),
            TableColumn(field="mf",
                        title="Magnetic Field",
                        formatter=num_formatter,
                        width=100),
        ],
        width=300,
        height=50,
        autosize_mode="none",
        index_position=None,
    )

    # Final layout
    import_layout = column(proposal_textinput, upload_div, upload_button,
                           file_select)
    layout_image = column(
        gridplot([[proj_v, None], [plot, proj_h]], merge_tools=False))
    colormap_layout = column(
        colormap,
        main_auto_checkbox,
        row(display_min_spinner, display_max_spinner),
        proj_auto_checkbox,
        row(proj_display_min_spinner, proj_display_max_spinner),
    )

    layout_controls = column(
        row(metadata_table, index_spinner,
            column(Spacer(height=25), index_slider)),
        row(add_event_button, hkl_button),
        row(events_list),
    )

    layout_overview = column(
        gridplot(
            [[overview_plot_x, overview_plot_y]],
            toolbar_options=dict(logo=None),
            merge_tools=True,
            toolbar_location="left",
        ), )

    tab_layout = row(
        column(import_layout, colormap_layout),
        column(layout_overview, layout_controls),
        column(roi_avg_plot, layout_image),
    )

    return Panel(child=tab_layout, title="hdf viewer")
示例#17
0
             u'primary visual cortex (striate cortex, area V1/17)',
             u'striatum',
             u'primary motor cortex (area M1, area 4)',
             u'posteroventral (inferior) parietal cortex',
             u'primary somatosensory cortex (area S1, areas 3,1,2)',
             u'cerebellum',
             u'cerebellar cortex',
             u'mediodorsal nucleus of thalamus']

# Build the data sources for the selected gene and brain structures
# (get_dataframes / get_boxplot_data are defined elsewhere in this example).
df1, df2 = get_dataframes(gene, structures)
source1 = ColumnDataSource(df1)
source2 = ColumnDataSource(df2)
source3, source4 = get_boxplot_data(df1)

# Figures built from the sources above.
age_plot = expression_by_age(source1, source2)
structure_plot = structure_boxplot(source3, source4)

# Heatmap for the genes in ``genes1``.
source5 = get_heatmap_data(genes1)
heatmap = plot_heatmap(source5, genes1)

# Stack the three figures vertically.
# NOTE(review): vplot is a legacy Bokeh layout helper that newer Bokeh
# releases no longer ship (column() is the replacement) — confirm against the
# installed Bokeh version.
plot = vplot(structure_plot,age_plot, heatmap)

# Region selector; starts with every structure selected. Any change calls
# update_plot (defined elsewhere), which presumably reads these module globals.
multi_select = MultiSelect(title="Brain Regions:", value=structures,
                           options=structures)
multi_select.on_change('value', update_plot)

# Gene selector, defaulting to genes[2]; shares the same update callback.
select = Select(title="Gene:", value=genes[2], options=genes)
select.on_change('value', update_plot)

# Register the page (gene selector, plots, region selector) with the document.
curdoc().add_root(vplot(select, plot, multi_select))
def bkapp(doc):
    
### Functions ###

    # functions for user dialogs

    def open_file(ftype):
        """Show a native "Open File" dialog and return the chosen path.

        Args:
            ftype: ``filetypes`` list for ``askopenfilename``, e.g.
                ``[("Video files", "*.mp4")]``.

        Returns:
            The selected path; falsy when the dialog is cancelled (callers
            test its truthiness).
        """
        # A hidden Tk root hosts the dialog. Destroy it even if the dialog
        # raises, so no stray Tk interpreter is left behind.
        root = Tk()
        root.withdraw()
        try:
            file = askopenfilename(filetypes=ftype,
                                   title='Open File',
                                   initialdir=os.getcwd())
        finally:
            root.destroy()
        return file
    
    def choose_directory():
        """Show a native directory picker and return the chosen directory.

        Returns:
            The selected directory; falsy when the dialog is cancelled
            (callers test its truthiness).
        """
        # Destroy the hidden Tk root even if the dialog raises.
        root = Tk()
        root.withdraw()
        try:
            out_dir = askdirectory()
        finally:
            root.destroy()
        return out_dir

    def write_output_directory(output_dir):
        """Ask whether the missing ``output_dir`` should be created.

        Returns:
            The ``askquestion`` answer; callers compare it with ``'yes'``.
        """
        # Destroy the hidden Tk root even if the dialog raises.
        root = Tk()
        root.withdraw()
        try:
            makeDir = askquestion('Make Directory','Output directory not set. Make directory: '
                +output_dir+'? If not, you\'ll be prompted to change directories.',icon = 'warning')
        finally:
            root.destroy()
        return makeDir

    def overwrite_file():
        """Ask whether an existing output file may be overwritten.

        Returns:
            The ``askquestion`` answer; callers compare it with ``'yes'``.
        """
        # Destroy the hidden Tk root even if the dialog raises.
        root = Tk()
        root.withdraw()
        try:
            overwrite = askquestion('Overwrite File','File already exists. Do you want to overwrite?',icon = 'warning')
        finally:
            root.destroy()
        return overwrite

    def update_filename():
        """Prompt for an .mp4 video and load it unless the dialog is cancelled."""
        chosen = open_file([("Video files", "*.mp4")])
        if not chosen:
            return
        load_data(filename=chosen)

    def change_directory():
        """Let the user pick an output directory, then record and display it.

        Returns:
            The chosen directory, or the dialog's falsy result on cancel (in
            which case nothing is updated).
        """
        out_dir = choose_directory()
        if not out_dir:
            return out_dir
        source.data["output_dir"] = [out_dir]
        outDir.text = out_dir
        return out_dir

    # load data from file

    def load_data(filename):
        """Load the first frame of a video into ``source`` and display it.

        The first channel returned by ``cv2.split`` for the first frame is
        flipped vertically (``np.flipud``) and used as both ``image_orig`` and
        ``image``. The output directory defaults to ``<input dir>/<file stem>``
        unless one was already chosen. If it does not exist, the user is asked
        to create it or pick another. On success ``source`` is reset and the
        image renderer on ``p`` is replaced. Without an output directory the
        load is cancelled.
        """
        vidcap = cv2.VideoCapture(filename)
        try:
            success, frame = vidcap.read()
        finally:
            # only the first frame is needed; free the capture handle
            vidcap.release()
        if not success or frame is None:
            print('Could not read a frame from: {}{:<100}'.format(filename, ' '), end="\r")
            return
        img_tmp,_,__ = cv2.split(frame)
        h,w = np.shape(img_tmp)
        img = np.flipud(img_tmp)
        radio_button_gp.active = 0
        input_dir, fname = os.path.split(filename)
        if source.data['output_dir'][0]=='':
            # splitext strips only the extension, so "a.1.mp4" and "a.2.mp4"
            # get distinct directories (split('.')[0] mapped both to "a").
            output_dir = os.path.join(input_dir, os.path.splitext(fname)[0])
        else:
            output_dir = source.data['output_dir'][0]
        if not os.path.isdir(output_dir):
            makeDir = write_output_directory(output_dir)
            if makeDir=='yes':
                os.mkdir(output_dir)
            else:
                output_dir = change_directory()
        if output_dir:
            source.data = dict(image_orig=[img], image=[img], bin_img=[0],
                x=[0], y=[0], dw=[w], dh=[h], num_contours=[0], roi_coords=[0],
                img_name=[fname], input_dir=[input_dir], output_dir=[output_dir])
            # replace any previously loaded image renderer
            curr_img = p.select_one({'name':'image'})
            if curr_img:
                p.renderers.remove(curr_img)
            p.image(source=source, image='image', x='x', y='y', dw='dw', dh='dh', color_mapper=cmap,level='image',name='image')
            p.plot_height=int(h/2)
            p.plot_width=int(w/2)
            inFile.text = fname
            outDir.text = output_dir
        else:
            print('Cancelled. To continue please set output directory.{:<100}'.format(' '),end="\r")

    # resetting sources for new data or new filters/contours

    def remove_plot():
        """Reset contour state: counter, overlay lines and number labels."""
        source.data["num_contours"] = [0]
        contours_found.text = 'Droplets Detected: 0'
        source_contours.data = {"xs": [], "ys": []}
        source_label.data = {"x": [], "y": [], "label": []}

    # apply threshold filter and display binary image

    def apply_filter():
        """Threshold the original image with the selected filter and show it.

        ``radio_button_gp.active`` picks the skimage threshold: 1 otsu,
        2 isodata, 3 mean, 4 li, 5 yen, 6 local (using the block-size and
        offset spinners). The binary result, scaled to 0/255, is stored in
        ``source.data["bin_img"]`` and displayed. With any other selection the
        original image is displayed unchanged. Existing contours are cleared
        first.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return

        img = np.squeeze(source.data['image_orig'])
        # remove contours if present
        if source.data["num_contours"] != [0]:
            remove_plot()

        threshold_by_choice = {
            1: filters.threshold_otsu,
            2: filters.threshold_isodata,
            3: filters.threshold_mean,
            4: filters.threshold_li,
            5: filters.threshold_yen,
            6: lambda im: filters.threshold_local(
                im, block_spinner.value, offset=offset_spinner.value),
        }
        threshold_fn = threshold_by_choice.get(radio_button_gp.active)

        if threshold_fn is None:
            bin_img = img
        else:
            bin_img = (img > threshold_fn(img)) * 255
            source.data["bin_img"] = [bin_img]
        source.data['image'] = [bin_img]

    # image functions for adjusting the binary image
    
    def close_img():
        """Apply a binary closing to the inverted binary image and show it.

        The closing runs on ``255 - bin_img`` and the result is inverted back,
        then stored as both the displayed ``image`` and the new ``bin_img``.
        Existing contours are cleared first. If no threshold filter is
        selected, only a message is printed.
        """
        if source.data["num_contours"] != [0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        # No intermediate copy of bin_img into "image": it would be
        # overwritten just below and only cause an extra data-change event.
        img = np.squeeze(source.data['bin_img'])
        closed_img = binary_closing(255 - img) * 255
        source.data['image'] = [255 - closed_img]
        source.data['bin_img'] = [255 - closed_img]

    def dilate_img():
        """Dilate the inverted binary image (``255 - bin_img``) and show it.

        The re-inverted result becomes both the displayed ``image`` and the new
        ``bin_img``. Existing contours are cleared first. If no threshold
        filter is selected, only a message is printed.
        """
        if source.data["num_contours"] != [0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        inverted = 255 - np.squeeze(source.data['bin_img'])
        dilated = binary_dilation(inverted) * 255
        source.data['image'] = [255 - dilated]
        source.data['bin_img'] = [255 - dilated]

    def erode_img():
        """Erode the inverted binary image (``255 - bin_img``) and show it.

        The re-inverted result becomes both the displayed ``image`` and the new
        ``bin_img``. Existing contours are cleared first. If no threshold
        filter is selected, only a message is printed.
        """
        if source.data["num_contours"] != [0]:
            remove_plot()
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        inverted = 255 - np.squeeze(source.data['bin_img'])
        eroded = binary_erosion(inverted) * 255
        source.data['image'] = [255 - eroded]
        source.data['bin_img'] = [255 - eroded]

    # the function for identifying closed contours in the image

    def find_contours(level=0.8, min_dim=20, max_dim=200):
        """Detect droplet outlines in the binary image and overlay them on ``p``.

        Contours of ``source.data['bin_img']`` at iso-value ``level`` are kept
        when two conditions hold:

        * their point count lies strictly inside ``contour_rng_slider``;
        * their bounding box is strictly between ``min_dim`` and ``max_dim``
          pixels in both directions.

        Kept contours are drawn from ``source_contours`` and numbered through
        ``source_label``. Their rounded boxes ``[width, height, min x, h - max y]``
        are stored in ``source.data['roi_coords']``. The y entry is presumably
        the top edge in the un-flipped video frame, as used for the ffmpeg crop.

        Args:
            level: iso-value passed to ``measure.find_contours``.
            min_dim: exclusive lower bound on box width/height, in pixels.
            max_dim: exclusive upper bound on box width/height, in pixels.
        """
        min_drop_size = contour_rng_slider.value[0]
        max_drop_size = contour_rng_slider.value[1]
        if radio_button_gp.active == 0:
            print("Please Select Filter for Threshold{:<100}".format(' '),end="\r")
            return
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return

        img = np.squeeze(source.data['bin_img'])
        h, _ = np.shape(img)
        contours = measure.find_contours(img, level)

        matched_cnt_x = []
        matched_cnt_y = []
        roi_coords = []
        label_text = []
        label_y = []
        for cnt in contours:
            # size filter on the number of contour points
            if not (min_drop_size < np.shape(cnt)[0] < max_drop_size):
                continue
            cnt_x = cnt[:, 1]
            cnt_y = cnt[:, 0]
            width = np.max(cnt_x) - np.min(cnt_x)
            height = np.max(cnt_y) - np.min(cnt_y)
            if min_dim < width < max_dim and min_dim < height < max_dim:
                matched_cnt_x.append(cnt_x)
                matched_cnt_y.append(cnt_y)
                roi_coords.append([round(width), round(height),
                                   round(np.min(cnt_x)), round(h - np.max(cnt_y))])
                label_text.append(str(len(roi_coords)))
                label_y.append(np.max(cnt_y))
        count = len(roi_coords)
        label_y = np.array(label_y, dtype=float)

        curr_overlay = p.select_one({'name':'overlay'})
        if curr_overlay:
            p.renderers.remove(curr_overlay)
        source.data["image"] = source.data["image_orig"]
        source.data["num_contours"] = [count]
        source.data["roi_coords"] = [roi_coords]
        source_contours.data = dict(xs=matched_cnt_x, ys=matched_cnt_y)
        # The multi-line renderer is bound to source_contours, so updating that
        # source redraws it. Create it only once; previously a new renderer
        # was stacked onto the plot on every call.
        if p.select_one({'name': 'contours'}) is None:
            p.multi_line(xs='xs',ys='ys',source=source_contours, color=(255,127,14),line_width=2, name="contours",level='glyph')
        if roi_coords:
            source_label.data = dict(x=np.array(roi_coords)[:, 2], y=label_y, label=label_text)
        else:
            print('No contours found. Try adjusting parameters or filter for thresholding.{:<100}'.format(' '),end="\r")
            source_label.data = dict(x=[],y=[],label=[])
        contours_found.text = 'Droplets Detected: '+str(len(matched_cnt_x))

    # write the contours and parameters to files

    def export_ROIs():
        """Save the contour parameters and ROI bounding boxes as CSV files.

        Two files are written into the output directory:

        * ``ContourParams.csv``: threshold filter index and contour size range,
          plus offset and block size for the local filter (index 6).
        * ``ROI_coords.csv``: one ``width,height,x,y`` row per ROI.

        If either file already exists the user is asked once. Existing files
        are replaced only after a 'yes'; missing files are always written.
        """
        if source.data["num_contours"]==[0]:
            print("No Contours Found! Find contours first.{:<100}".format(' '),end="\r")
            return
        out_dir = source.data['output_dir'][0]
        hdr = 'threshold filter,contour min,contour max'
        params = [radio_button_gp.active,
                  contour_rng_slider.value[0],
                  contour_rng_slider.value[1]]
        if radio_button_gp.active == 6:
            hdr = hdr + ',local offset,local block size'
            # list.append takes one item; the two-argument call raised TypeError
            params.extend([offset_spinner.value, block_spinner.value])
        params_out = os.path.join(out_dir, 'ContourParams.csv')
        out_fullpath = os.path.join(out_dir, 'ROI_coords.csv')
        # Ask once if either file exists, so an existing ROI file is never
        # skipped silently.
        overwrite = 'no'
        if os.path.exists(params_out) or os.path.exists(out_fullpath):
            overwrite = overwrite_file()
        if overwrite == 'yes' or not os.path.exists(params_out):
            np.savetxt(params_out,np.array([params]),delimiter=',',header=hdr,comments='')
        roi_coords = np.array(source.data["roi_coords"][0])
        if overwrite == 'yes' or not os.path.exists(out_fullpath):
            np.savetxt(out_fullpath,roi_coords,delimiter=',',header='width,height,x,y',comments='')
            print('Successfully saved ROIs coordinates as: '+out_fullpath,end='\r')
            source.data['roi_coords'] = [roi_coords]

    # function for loading previously made files or error handling for going out of order

    def check_ROI_files():
        """Reload saved ROIs and contour parameters from the output directory.

        Steps:

        1. Rows of ``ROI_coords.csv`` replace ``source.data['roi_coords']``
           and ``num_contours``. A message is printed if they differ from the
           current ROIs.
        2. The threshold filter, local-threshold settings and contour size
           range are restored from ``ContourParams.csv``.
        3. ``find_contours`` is re-run.

        A missing coordinates file only prints a message. A missing parameter
        file keeps the loaded ROIs and skips steps 2-3, instead of raising.
        """
        out_dir = source.data["output_dir"][0]
        coords_file = os.path.join(out_dir,'ROI_coords.csv')
        if not os.path.exists(coords_file):
            print("ROI files not found! Check save directory or export ROIs.{:<100}".format(' '),end="\r")
            return
        n_cnt_curr = source.data["num_contours"][0]
        roi_coords_curr = source.data["roi_coords"][0]
        df_tmp = pd.read_csv(coords_file, sep=',')
        roi_coords = np.array(df_tmp.values)
        n_cnt = len(roi_coords)
        if n_cnt_curr != n_cnt or not np.array_equal(roi_coords_curr,roi_coords):
            print('Current ROIs are different from saved ROI file! ROIs from saved file will be used instead and plot updated.',end="\r")
        source.data["num_contours"] = [n_cnt]
        source.data["roi_coords"] = [roi_coords]
        params_file = os.path.join(out_dir,'ContourParams.csv')
        if not os.path.exists(params_file):
            # cannot re-detect contours without the parameters; keep file ROIs
            print('Contour parameter file not found! Using saved ROIs without re-detecting contours.{:<100}'.format(' '),end="\r")
            return
        params_df = pd.read_csv(params_file)
        thresh_ind = params_df["threshold filter"].values[0]
        radio_button_gp.active = int(thresh_ind)
        if thresh_ind == 6:
            offset_spinner.value = int(params_df["local offset"].values[0])
            block_spinner.value = int(params_df["local block size"].values[0])
        contour_rng_slider.value = tuple([int(params_df["contour min"].values[0]),int(params_df["contour max"].values[0])])
        find_contours()

    # use FFMPEG to crop out regions from original mp4 and save to file

    def create_ROI_movies():
        """Crop every ROI out of the input video with ffmpeg.

        ROIs are first reloaded and validated via ``check_ROI_files``. For each
        ROI ``i``, ``ROI<i>.mp4`` is written to the output directory. It holds
        a ``side`` x ``side`` crop whose top-left corner is the ROI's (x, y)
        shifted by ``padding`` pixels. Existing movies are replaced only after
        confirmation. A non-zero ffmpeg exit status prints a message and the
        remaining ROIs are still processed.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return
        check_ROI_files()
        # Fixed square crop size. For ROI-sized crops use
        # roi width/height + 2*padding instead of side.
        side = 100
        padding = 20
        out_dir = source.data['output_dir'][0]
        roi_coords_file = os.path.join(out_dir,'ROI_coords.csv')
        if source.data["num_contours"]==[0]:
            print("No contours found! Find contours first.{:<100}".format(' '),end="\r")
            return
        if not os.path.exists(roi_coords_file):
            print("ROI file does not exist! Check save directory or export ROIs.{:<100}".format(' '),end="\r")
            return

        print('Creating Movies...{:<100}'.format(' '),end="\r")
        n_rois = source.data["num_contours"][0]
        roi_coords = np.array(source.data["roi_coords"][0])
        inPath = os.path.join(source.data['input_dir'][0],source.data['img_name'][0])
        with tqdm(total=n_rois) as pbar:
            for i in range(n_rois):
                outPath = os.path.join(out_dir,'ROI'+str(i+1)+'.mp4')
                crop = f"crop={side}:{side}:{(roi_coords[i,2]-padding)}:{(roi_coords[i,3]-padding)}"
                # argv list instead of a shell string: paths containing quotes
                # or other shell metacharacters are passed through unchanged.
                command = ["ffmpeg", "-i", inPath, "-vf", crop, "-y", outPath]
                overwrite = 'no'
                if os.path.exists(outPath):
                    overwrite = overwrite_file()
                if overwrite == 'yes' or not os.path.exists(outPath):
                    # subprocess.call returns the exit status (check_call raised
                    # instead, which made this error branch unreachable).
                    if subprocess.call(command) != 0:
                        print('An error occurred while creating movies! Check terminal window.{:<100}'.format(' '),end="\r")
                pbar.update()

    # change the display range on images from slider values

    def update_image():
        """Copy the display-range slider bounds onto the grey colour mapper."""
        bounds = display_range_slider.value
        cmap.low = bounds[0]
        cmap.high = bounds[1]

    # create statistics files for each mp4 region specific file

    def process_ROIs():
        """Compute per-frame statistics for each ROI movie and write them to CSV.

        ROI ``i`` (0-based) is read from ``ROI{i+1}.mp4`` in the output directory.
        Its statistics go to ``stats_ROI{i+1}.csv``, with one row per frame and
        the 13 columns listed in ``hdr``. Statistics are computed on the first
        channel returned by ``cv2.split``. Frames are taken to be 5 minutes
        apart, so the ``time`` column is in hours.

        An existing CSV is rewritten only when ``overwrite_file()`` answers
        'yes'. If no ROI movie was missing, the statistics computed in this run
        are published to ``sources_stats``, keyed by ROI index.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return
        check_ROI_files()
        hdr = 'roi,time,area,mean,variance,min,max,median,skewness,kurtosis,rawDensity,COMx,COMy'
        cols = hdr.split(',')
        out_dir = source.data['output_dir'][0]

        def frame_stats(roi, t, channel):
            # One CSV row in `hdr` order for a single-channel frame.
            desc = describe(channel, axis=None)
            cx, cy = center_of_mass(channel)
            return [roi, t, desc.nobs, desc.mean, desc.variance,
                    desc.minmax[0], desc.minmax[1], np.median(channel),
                    desc.skewness, desc.kurtosis, np.sum(channel), cx, cy]

        new_stats = []  # per-ROI arrays computed (not skipped) in this run
        n_cnt = source.data["num_contours"][0]
        if n_cnt == 0:
            print("No contours found! Find contours first.{:<100}".format(' '),end="\r")
        pbar = None
        for i in range(n_cnt):
            inPath = os.path.join(out_dir, 'ROI'+str(i+1)+'.mp4')
            outPath = os.path.join(out_dir, 'stats_ROI'+str(i+1)+'.csv')
            if not os.path.exists(inPath):
                print('ROI movie not found! Create ROI movie first.{:<100}'.format(' '),end="\r")
                break
            vidcap = cv2.VideoCapture(inPath)
            try:
                last_frame = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
                if pbar is None:
                    pbar = tqdm(total=last_frame*n_cnt)
                success, frame = vidcap.read()
                if not success or last_frame < 1:
                    print('ROI movie could not be read! Recreate ROI movie.{:<100}'.format(' '),end="\r")
                    break
                img_stats = np.zeros((last_frame, len(cols)))
                img_stats[0] = frame_stats(i, 0, cv2.split(frame)[0])
                pbar.update()
                overwrite = 'no'
                if os.path.exists(outPath):
                    overwrite = overwrite_file()
                    if overwrite == 'no':
                        pbar.update(last_frame-1)
                if overwrite == 'yes' or not os.path.exists(outPath):
                    # Read frames sequentially. Seeking before every frame makes
                    # the decoder restart from a keyframe each time.
                    for j in range(1, last_frame):
                        success, frame = vidcap.read()
                        if not success:
                            # The container's frame count can overestimate;
                            # keep only the frames actually read.
                            pbar.update(last_frame-j)
                            img_stats = img_stats[:j]
                            break
                        img_stats[j] = frame_stats(i, j*5/60, cv2.split(frame)[0])
                        pbar.update(1)
                    new_stats.append(img_stats)
                    np.savetxt(outPath, img_stats, delimiter=',', header=hdr, comments='')
            finally:
                vidcap.release()
        else:
            # Only reached when no ROI movie was missing (loop not broken).
            # Key by the actual roi values, so ROIs whose CSV was kept are
            # skipped instead of raising KeyError.
            if new_stats:
                df = pd.DataFrame(np.concatenate(new_stats), columns=cols)
                for roi, df_roi in df.groupby('roi'):
                    sources_stats[int(roi)] = ColumnDataSource(df_roi)
        if pbar is not None:
            pbar.close()

    # load statistics CSVs and first ROI mp4 files and display in plots

    def load_ROI_files():
        """Load the per-ROI statistics CSVs and populate the statistics plots.

        Expects ``stats_ROI1.csv`` ... ``stats_ROI{n}.csv`` in the output
        directory, where n is the number of contours. For every ROI, a hidden
        line named 'ROI n' is added to each statistics tab. Lines that already
        exist under that name are left alone.

        The ROI MultiSelect and the movie radio group are then repopulated,
        and ROI 1 is selected.
        """
        if source.data['input_dir'][0] == '':
            print('No image loaded! Load image first.{:<100}'.format(' '),end="\r")
            return
        check_ROI_files()
        n_cnt = source.data["num_contours"][0]
        basepath = os.path.join(source.data["output_dir"][0],'stats')
        all_files = [basepath+'_ROI'+str(i+1)+'.csv' for i in range(n_cnt)]
        if n_cnt == 0 or not all(os.path.exists(f) for f in all_files):
            print('Not enough files! Check save directory or calculate new stats.{:<100}'.format(' '),end="\r")
            return
        df = pd.concat((pd.read_csv(f) for f in all_files), ignore_index=True)
        group = df.groupby('roi')
        # Columns after roi/time/area, one per tab in p_stats order.
        stat_columns = [str(c) for c in df.columns[3:]]
        n_plots = min(len(p_stats), len(stat_columns))
        # Category20 reordered: even-indexed entries first, then odd-indexed ones.
        palette = list(Category20[20])
        colors = palette[0::2] + palette[1::2]
        OPTIONS = []
        LABELS = []
        pbar = tqdm(total=n_plots*len(group))
        try:
            for k, (roi, df_roi) in enumerate(group):
                sources_stats[roi] = ColumnDataSource(df_roi)
                name = 'ROI '+str(int(roi)+1)
                OPTIONS.append((str(int(roi)+1), name))
                LABELS.append(name)
                color = colors[k % len(colors)]
                for plot, col in zip(p_stats, stat_columns):
                    if plot.select_one({'name': name}) is None:
                        plot.line(x='time', y=col, source=sources_stats[roi],
                            name=name, visible=False, line_color=color)
                        plot.xaxis.axis_label = "Time [h]"
                        plot.yaxis.axis_label = col
                        # A HoverTool without explicit renderers already covers
                        # every line. Adding one per ROI would repeat each
                        # tooltip once per ROI.
                        if not any(isinstance(t, HoverTool) for t in plot.tools):
                            plot.add_tools(HoverTool(tooltips=TOOLTIPS))
                        plot.toolbar_location = "left"
                    pbar.update(1)
        finally:
            pbar.close()
        ROI_multi_select.options = OPTIONS
        ROI_multi_select.value = ["1"]
        ROI_movie_radio_group.labels = LABELS
        ROI_movie_radio_group.active = 0

    # show/hide curves from selected/deselected labels for ROIs in statistics plots

    def update_ROI_plots():
        """Show the statistics lines of the ROIs selected in ``ROI_multi_select``.

        In every statistics plot, the line named 'ROI n' is made visible when
        "n" is selected and hidden otherwise. ROIs whose lines have not been
        drawn yet (statistics not loaded) are skipped instead of raising.
        """
        n_cnt = source.data["num_contours"][0]
        selected = set(ROI_multi_select.value)
        pbar = tqdm(total=len(stats)*n_cnt)
        try:
            for j in range(n_cnt):
                name = 'ROI '+str(j+1)
                visible = str(j+1) in selected
                for plot in p_stats:
                    glyph = plot.select_one({'name': name})
                    if glyph is not None:
                        glyph.visible = visible
                    pbar.update(1)
        finally:
            pbar.close()

    # load and display the selected ROI's mp4

    def load_ROI_movie():
        """Show the first frame of the ROI movie chosen in ``ROI_movie_radio_group``.

        Reads ``ROI{n}.mp4`` from the output directory and stores its first
        frame (first channel, flipped vertically) in ``sourceROI``. Then resets
        the time slider to 0, with its end at the last frame (5 min per frame,
        in hours). The image renderer is created once and reused afterwards.
        """
        idx = ROI_movie_radio_group.active
        if idx is None:
            return
        in_fname = 'ROI'+str(idx+1)+'.mp4'
        inPath = os.path.join(source.data['output_dir'][0],in_fname)
        if not os.path.exists(inPath):
            print('ROI movie not found! Check save directory or create ROI movie.',end="\r")
            return
        vidcap = cv2.VideoCapture(inPath)
        try:
            last_frame = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            success, frame = vidcap.read()
        finally:
            vidcap.release()
        if not success:
            print('ROI movie could not be read! Recreate ROI movie.',end="\r")
            return
        img = np.flipud(cv2.split(frame)[0])
        h, w = img.shape
        # Swap the data before touching the slider. Moving the slider triggers
        # update_ROI_movie, which opens the movie named in sourceROI.
        sourceROI.data = dict(image=[img], x=[0], y=[0], dw=[w], dh=[h],
            img_name=['ROI'+str(idx+1)])
        ROI_movie_slider.end = (last_frame-1)*5/60
        ROI_movie_slider.value = 0
        # The renderer's name is the literal 'img_name' (see below), so look it
        # up by that name. It reads sourceROI, so one renderer serves every ROI.
        if p_ROI.select_one({'name': 'img_name'}) is None:
            p_ROI.image(source=sourceROI, image='image', x='x', y='y',
               dw='dw', dh='dh', color_mapper=cmap, name='img_name')

    # change the displayed frame from slider movement

    def update_ROI_movie():
        """Display the frame of the current ROI movie that matches the time slider.

        The slider is in hours at 5 minutes per frame, so the frame index is
        ``round(value * 60 / 5)``. Does nothing while no movie is loaded
        (placeholder ``img_name``) or when the frame cannot be read.
        """
        name = sourceROI.data['img_name'][0]
        if not isinstance(name, str):
            return  # placeholder data: no ROI movie loaded yet
        frame_idx = round(ROI_movie_slider.value*60/5)
        inPath = os.path.join(source.data['output_dir'][0], name+'.mp4')
        vidcap = cv2.VideoCapture(inPath)
        try:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            success, frame = vidcap.read()
        finally:
            vidcap.release()
        if not success:
            return
        sourceROI.data['image'] = [np.flipud(cv2.split(frame)[0])]

    # the following 2 functions are used to animate the mp4

    def update_ROI_slider():
        """Advance the ROI movie slider by one frame (5 minutes = 5/60 h).

        If that step would pass the slider's end, ``animate_ROI_movie`` is
        called instead. This callback only runs while playing, so that call
        toggles playback to paused.
        """
        next_time = ROI_movie_slider.value + 5/60
        if next_time <= ROI_movie_slider.end:
            ROI_movie_slider.value = next_time
        else:
            animate_ROI_movie()
        return callback_id

    def animate_ROI_movie():
        global callback_id
        if ROI_movie_play_button.label == '► Play':
            ROI_movie_play_button.label = '❚❚ Pause'
            callback_id = curdoc().add_periodic_callback(update_ROI_slider, 10)
        else:
            ROI_movie_play_button.label = '► Play'
            curdoc().remove_periodic_callback(callback_id)
        return callback_id

### Application Content ###

    # Main plot used for segmentation and contour finding.

    # Greyscale colour mapper; update_image moves its low/high bounds from the
    # display-range slider, and load_ROI_movie uses it for the ROI image.
    cmap = LinearColorMapper(palette="Greys256", low=0, high=255)
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save,box_select,lasso_select"
    # Hover tooltips for image renderers (presumably used by image-drawing code
    # defined elsewhere in this function).
    IMG_TOOLTIPS = [('name', "@img_name"),("x", "$x"),("y", "$y"),("value", "@image")]

    # Shared application state: image data, contour results (num_contours,
    # roi_coords) and the input/output directories. The single-element
    # placeholders are replaced once a file is loaded.
    source = ColumnDataSource(data=dict(image=[0],bin_img=[0],image_orig=[0],
            x=[0], y=[0], dw=[0], dh=[0], num_contours=[0], roi_coords=[0],
            input_dir=[''],output_dir=[''],img_name=['']))
    source_label = ColumnDataSource(data=dict(x=[0], y=[0], label=['']))
    source_contours = ColumnDataSource(data=dict(xs=[0], ys=[0]))

    # Text labels drawn over the main image (presumably the ROI numbers).
    roi_labels = LabelSet(x='x', y='y', text='label',source=source_label,
        level='annotation',text_color='white',text_font_size='12pt')

    # Main figure; the labels above are attached as an annotation layer.
    p = figure(tools=TOOLS, toolbar_location=("right"))
    p.add_layout(roi_labels)
    p.x_range.range_padding = p.y_range.range_padding = 0

    # Hide gridlines and axes so only the image is shown.
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.axis.visible = False


    # ROI plots: one image plot for the selected ROI movie, plus one tab with a
    # line plot per statistic.

    # Frame of the ROI movie currently shown in p_ROI. img_name holds the
    # movie's base name (e.g. 'ROI1') once load_ROI_movie has run; 0 until then.
    sourceROI = ColumnDataSource(data=dict(image=[0],
            x=[0], y=[0], dw=[0], dh=[0], img_name=[0]))
    # ROI index -> ColumnDataSource with that ROI's per-frame statistics.
    sources_stats = {}

    TOOLTIPS = [('name','$name'),('time', '@time'),('stat', "$y")]
    # Tab titles, one per statistics column plotted (the CSV columns after
    # roi/time/area, in order).
    stats = np.array(['mean','var','min','max','median','skew','kurt','rawDensity','COMx','COMy'])
    p_stats = []
    tabs = []
    for i in range(len(stats)):
        p_stats.append(figure(tools=TOOLS, plot_height=300, plot_width=600))
        p_stats[i].x_range.range_padding = p_stats[i].y_range.range_padding = 0
        tabs.append(Panel(child=p_stats[i], title=stats[i]))

    # Plot that displays one frame of the selected ROI movie.
    p_ROI = figure(tools=TOOLS, toolbar_location=("right"), plot_height=300, plot_width=300)
    p_ROI.x_range.range_padding = p_ROI.y_range.range_padding = 0

    # Hide gridlines and axes so only the frame is shown.
    p_ROI.xgrid.grid_line_color = p_ROI.ygrid.grid_line_color = None
    p_ROI.axis.visible = False


    # Widgets - Buttons, Sliders, Text, Etc.

    # Help text rendered as HTML at the top of the app, describing the workflow.
    intro = Div(text="""<h2>Droplet Recognition and Analysis with Bokeh</h2> 
        This application is designed to help segment a grayscale image into 
        regions of interest (ROIs) and perform analysis on those regions.<br>
        <h4>How to Use This Application:</h4>
        <ol>
        <li>Load in a grayscale mp4 file and choose a save directory.</li>
        <li>Apply various filters for thresholding. Use <b>Close</b>, <b>Dilate</b> 
        and <b>Erode</b> buttons to adjust each binary image further.</li>
        <li>Use <b>Find Contours</b> button to search the image for closed shape. 
        The <b>Contour Size Range</b> slider will change size of the perimeter to
        be identified. You can apply new thresholds and repeat until satisfied with
        the region selection. Total regions detected is displayed next to
        the button.</li>
        <li>When satisfied, use <b>Export ROIs</b> to write ROI locations and 
        contour finding parameters to file.</li>
        <li><b>Create ROI Movies</b> to write mp4s of the selected regions.</li>
        <li>Use <b>Calculate ROI Stats</b> to perform calculations on the 
        newly created mp4 files.</li>
        <li>Finally, use <b>Load ROI Files</b> to load in the data that you just
        created and view the plots. The statistics plots can be overlaid by 
        selecting multiple labels. Individual ROI mp4s can be animated or you can
        use the slider to move through the frames.</li>
        </ol>
        Note: messages and progress bars are displayed below the GUI.""",
        style={'font-size':'10pt'},width=1000)

    # File selection and display of the chosen input file.
    file_button = Button(label="Choose File",button_type="primary")
    file_button.on_click(update_filename)
    inFile = PreText(text='Input File:\n'+source.data["img_name"][0], background=(255,255,255,0.5), width=500)

    # Thresholding method; every change re-runs apply_filter.
    filter_LABELS = ["Original","OTSU", "Isodata", "Mean", "Li","Yen","Local"]
    radio_button_gp = RadioButtonGroup(labels=filter_LABELS, active=0, width=600)
    radio_button_gp.on_change('active', lambda attr, old, new: apply_filter())

    # Parameters for the "Local" filter (per the titles); changes also re-run
    # apply_filter. Block size steps by 2 from 1, so it stays odd.
    offset_spinner = Spinner(low=0, high=500, value=1, step=1, width=100, title="Local: Offset",
                            background=(255,255,255,0.5))
    offset_spinner.on_change('value', lambda attr, old, new: apply_filter())
    block_spinner = Spinner(low=1, high=101, value=25, step=2, width=100, title="Local: Block Size",
                           background=(255,255,255,0.5))
    block_spinner.on_change('value', lambda attr, old, new: apply_filter())

    # Morphological adjustments of the binary image.
    closing_button = Button(label="Close",button_type="default", width=100)
    closing_button.on_click(close_img)
    dilation_button = Button(label="Dilate",button_type="default", width=100)
    dilation_button.on_click(dilate_img)
    erosion_button = Button(label="Erode",button_type="default", width=100)
    erosion_button.on_click(erode_img)

    # Contour search: size range, trigger button and count display.
    contour_rng_slider = RangeSlider(start=10, end=500, value=(200,350), step=1, width=300,
            title="Contour Size Range", background=(255,255,255,0.5), bar_color='gray')
    contour_button = Button(label="Find Contours", button_type="success")
    contour_button.on_click(find_contours)
    contours_found = PreText(text='Droplets Detected: '+str(source.data["num_contours"][0]), background=(255,255,255,0.5))

    exportROIs_button = Button(label="Export ROIs", button_type="success", width=200)
    exportROIs_button.on_click(export_ROIs)

    # Output directory selection and display.
    changeDir_button = Button(label="Change Directory",button_type="primary", width=150)
    changeDir_button.on_click(change_directory)
    outDir = PreText(text='Save Directory:\n'+source.data["output_dir"][0], background=(255,255,255,0.5), width=500)

    create_ROIs_button = Button(label="Create ROI Movies",button_type="success", width=200)
    create_ROIs_button.on_click(create_ROI_movies)

    process_ROIs_button = Button(label="Calculate ROI Stats",button_type="success")
    process_ROIs_button.on_click(process_ROIs)

    # Empty, narrow figure used only for its vertical "Display Range" title,
    # placed next to the vertical slider below.
    display_rng_text = figure(title="Display Range", title_location="left",
                        width=40, height=300, toolbar_location=None, min_border=0,
                        outline_line_color=None)
    display_rng_text.title.align="center"
    display_rng_text.title.text_font_size = '10pt'
    display_rng_text.x_range.range_padding = display_rng_text.y_range.range_padding = 0

    # Grey-level display window; update_image copies it onto cmap.
    display_range_slider = RangeSlider(start=0, end=255, value=(0,255), step=1,
        orientation='vertical', direction='rtl',
        bar_color='gray', width=40, height=300, tooltips=True)
    display_range_slider.on_change('value', lambda attr, old, new: update_image())

    load_ROIfiles_button = Button(label="Load ROI Files",button_type="primary")
    load_ROIfiles_button.on_click(load_ROI_files)

    # Which ROIs' statistics lines are visible; options are filled by load_ROI_files.
    ROI_multi_select = MultiSelect(value=[], width=100, height=300)
    ROI_multi_select.on_change('value', lambda attr, old, new: update_ROI_plots())

    # ROI movie viewer: choose the ROI, then scrub time in hours (5-minute steps).
    ROI_movie_radio_group = RadioGroup(labels=[],width=60)
    ROI_movie_radio_group.on_change('active', lambda attr, old, new: load_ROI_movie())
    ROI_movie_slider = Slider(start=0,end=100,value=0,step=5/60,title="Time [h]", width=280)
    ROI_movie_slider.on_change('value', lambda attr, old, new: update_ROI_movie())

    # Handle of the periodic callback that animates the ROI movie; read by
    # update_ROI_slider (see also animate_ROI_movie).
    callback_id = None

    ROI_movie_play_button = Button(label='► Play',width=50)
    ROI_movie_play_button.on_click(animate_ROI_movie)

# initialize some data without having to choose file
    # fname = os.path.join(os.getcwd(),'data','Droplets.mp4')
    # load_data(filename=fname)


### Layout & Initialize application ###

    # Movie viewer: ROI chooser beside the frame, time slider and play button below.
    ROI_layout = layout([
        [ROI_movie_radio_group, p_ROI],
        [ROI_movie_slider,ROI_movie_play_button]
        ])

    # Whole page, one row per step of the workflow described in `intro`.
    app = layout(children=[
        [intro],
        [file_button,inFile],
        [radio_button_gp, offset_spinner, block_spinner],
        [closing_button, dilation_button, erosion_button],
        [contour_rng_slider, contour_button, contours_found],
        [exportROIs_button, outDir, changeDir_button],
        [create_ROIs_button, process_ROIs_button],
        [display_rng_text, display_range_slider, p],
        [load_ROIfiles_button],
        [ROI_layout, ROI_multi_select, Tabs(tabs=tabs)]
    ])

    doc.add_root(app)
示例#19
0
# Load the raw records and derive the plotting source from them.
df = load_data()
source = preprocess_data(df)

# Year filter: one entry per distinct 'CalendarYearIssued' value, as sorted
# text. Option keys are the positions "0", "1", ...; every year starts selected.
years_list = sorted(df['CalendarYearIssued'].astype(str).unique().tolist())

year_selection = MultiSelect(
    title='Year',
    value=list(map(str, range(len(years_list)))),
    options=list(zip(map(str, range(len(years_list))), years_list)),
)


##########
##########    Create select all button
##########
def update_selectall():
    """Select every option of ``year_selection`` (by option key)."""
    all_keys = [option[0] for option in year_selection.options]
    year_selection.value = all_keys

select_all = Button(label='Select All')
select_all.on_click(update_selectall)


##########
##########    Create run button
def create_world_cases_time_series_tab():
    """Build the 'World Cases Time Series' tab.

    The tab holds a log-scale line chart of confirmed cases per region, with
    4-, 7- and 14-day doubling reference curves. Beside it, a MultiSelect
    shows or hides region lines; a region's line is created the first time
    the region is selected.

    Returns:
        Panel titled 'World Cases Time Series'.
    """
    _, source_CDS = get_country_cases_vs_time()

    line_figure = figure(
        x_axis_type='datetime',
        y_axis_type='log',
        title='World Confirmed Cases by Region',
        x_axis_label='Date',
        y_axis_label='Number of Confirmed Cases (Logarithmic Scale)',
        active_scroll='wheel_zoom')

    starting_regions = ['China', 'US', 'Italy']
    # Columns of source_CDS that are not regions (kept out of the selector).
    excluded_columns_set = {'index', 'date'}

    # Reference curves 2 ** (day / d), which double every d days.
    n_days = len(source_CDS.data['index'])
    doubling_specs = zip((4, 7, 14), ('four', 'seven', 'fourteen'), gray(6)[2:5])
    for days, word, shade in doubling_specs:
        col = f'{word}_day_doubling'
        excluded_columns_set.add(col)
        source_CDS.data[col] = 2**(np.arange(n_days) / days)
        line_figure.line(y=col,
                         legend_label=f'{days}-day Doubling Time',
                         line_color=shade,
                         name=col,
                         alpha=0.6,
                         muted_alpha=0.2,
                         line_width=3,
                         source=source_CDS,
                         x='date',
                         visible=True)

    region_line_style = dict(x='date', source=source_CDS, line_width=4, alpha=0.6)

    lines = {}
    for region, shade in zip(starting_regions, viridis(len(starting_regions))):
        lines[region] = line_figure.line(y=region, name=region, line_color=shade,
                                         **region_line_style)

    line_figure.legend.location = 'top_left'
    line_figure.legend.click_policy = 'hide'

    hover_tool = HoverTool(
        tooltips=[('Date', '@date{%F}'), ('Region', '$name'),
                  ('Number of Cases', '@$name{0,0}')],
        formatters={'@date': 'datetime'},
        renderers=list(line_figure.renderers),
    )
    line_figure.add_tools(hover_tool)

    labels = [col for col in source_CDS.data if col not in excluded_columns_set]

    def region_select_callback(attr, old, new):
        # Hide regions that were just deselected.
        for region in set(old) - set(new):
            lines[region].visible = False
        # Show newly selected regions, drawing their line on first selection.
        for region in set(new) - set(old):
            if region in lines:
                lines[region].visible = True
                continue
            lines[region] = line_figure.line(
                y=region,
                name=region,
                line_color=np.random.choice(Viridis256),
                **region_line_style)
            hover_tool.renderers = [*hover_tool.renderers, lines[region]]

    region_select = MultiSelect(title='Select Regions to Show',
                                value=starting_regions,
                                options=labels,
                                sizing_mode='stretch_height')
    region_select.on_change('value', region_select_callback)

    tab_content = row([column([line_figure]), column([region_select])])
    return Panel(child=tab_content, title='World Cases Time Series')
示例#21
0
    bottom_row.children[0] = cost_plot


# Parse the command-line arguments and open the database they name.
parser = create_parser()
args = sys.argv[1:]
parsed_args = parser.parse_args(args=args)
conn = connect_to_database(db_path=parsed_args.database)

# Data access layer on top of the open connection.
data = DataProvider(conn)

# Selection widgets. The multi-choice pickers start with every option
# selected; the single-choice pickers start at a fixed index of their options.
scenario_select = MultiSelect(title="Select Scenario(s):",
                              value=data.scenario_options,
                              options=data.scenario_options)
period_select = MultiSelect(title="Select Period(s):",
                            value=data.period_options,
                            options=data.period_options)
stage_select = Select(
    title="Select Stage:", value=data.stage_options[0], options=data.stage_options)
zone_select = Select(
    title="Select Load Zone:", value=data.zone_options[0], options=data.zone_options)
capacity_select = Select(
    title="Select Capacity Metric:", value=data.cap_options[2], options=data.cap_options)
示例#22
0
    return (settings_files_select_options)


make_cases_div = Div(text="<h2>Make case(s)</h2>", sizing_mode="stretch_width")

# Settings-file picker: description text, then the list of files.
make_cases_select_div = Div(
    text="Select one settings file.",
    css_classes=['item-description'],
    sizing_mode='stretch_width',
)

# get_settings_files() presumably refreshes the module-level
# settings_file_names used on the next line; verify against its definition.
get_settings_files()
settings_files_select_options = retrieve_select_options(settings_file_names)

make_cases_select = MultiSelect(value=[], options=settings_files_select_options)

# Action buttons.
make_cases_button = Button(label="Make case(s)", button_type='primary', disabled=False)
refresh_button = Button(label="Refresh files", button_type='default', disabled=False)

make_cases_output = Div(text="Click the button above to create and build " +
                        "the cases specified in the chosen settings file.",
                        css_classes=["item-description"],
                        style={
                            'overflow': 'auto',
                            'width': '100%',
示例#23
0
class Selector:
    """A titled single-choice picker built on a Bokeh ``MultiSelect``.

    A MultiSelect is used instead of a Select (presumably so that several
    options, ``size=8``, are visible at once). A change callback keeps exactly
    one entry selected.

    Args:
        name: Heading shown above the widget by ``layout``.
        descr: Sub-heading shown under ``name``.
        kind: Word used in the footer link text ("About the {kind}").
        css_classes: CSS classes of the column returned by ``layout``.
        entries: Mapping whose sorted keys become the options.
        default: Initially selected option.
        title: Widget title. If None, a placeholder "." is used together with
            the 'hide-title' CSS class.
        none_allowed: If true, a "None" option is prepended.
    """

    def __init__(
        self,
        name="Specials",
        descr="Choose one",
        kind="specials",
        css_classes=None,
        entries=None,
        default="",
        title=None,
        none_allowed=False,
    ):
        # None defaults: a literal [] / {} default would be one object shared
        # by every Selector built without that argument.
        if css_classes is None:
            css_classes = []
        if entries is None:
            entries = {}
        self.name = name
        self.descr = descr
        self.entries = entries
        self.kind = kind
        # Classes for the outer layout column; the widget gets its own below.
        self.css_classes = css_classes
        options = sorted(entries.keys())
        if none_allowed:
            options = ["None"] + options
        if title is None:
            title = "."
            widget_css_classes = ["deli-selector", "hide-title"]
        else:
            widget_css_classes = ["deli-selector"]
        self.widget = MultiSelect(
            options=options,
            value=[default],
            size=8,
            name="deli-selector",
            title=title,
            css_classes=widget_css_classes,
        )

        # HACK: force MultiSelect to keep exactly one value selected. Revert
        # changes that leave zero or several entries, so `value` never indexes
        # an empty list. Only revert to a valid single-item value; otherwise
        # the revert could keep re-triggering this callback.
        def multi_select_hack(attr, old, new):
            if len(new) != 1 and len(old) == 1:
                self.widget.value = old

        self.widget.on_change("value", multi_select_hack)

    @property
    def value(self):
        """Return the single selected option."""
        # HACK: MultiSelect stores a list; the callback above keeps it at one item.
        return self.widget.value[0]

    def layout(self, additional_widgets=None, width=None):
        """Return a column with a heading, the widget row and an "About" footer.

        Args:
            additional_widgets: Extra widgets placed in the row after the selector.
            width: Row width; defaults to 160 px per widget in the row.
        """
        if additional_widgets is None:
            additional_widgets = []
        title = Div(
            text="""<h2>{0}</h2><h3>{1}</h3>""".format(self.name, self.descr),
            css_classes=["controls"],
        )
        footer = Div(
            text="""<a href="#">About the {0}</a>""".format(self.kind),
            css_classes=["controls", "controls-footer"],
        )
        if width is None:
            width = 160 * (1 + len(additional_widgets))
        return column(
            title,
            row(
                self.widget,
                *additional_widgets,
                width=width,
                css_classes=["controls"],
            ),
            footer,
            css_classes=self.css_classes,
        )
示例#24
0
def build_graph_plot(G, title=""):
    """ Return a Bokeh layout with a plot of the given networkx graph and a layer filter.

    Nodes are coloured by their integer attribute ``n`` (indexing
    ``Spectral4``) and positioned with a spring layout. A MultiSelect below
    the plot hides every node except those whose ``n`` layer is selected
    ('Full Graph' selects layers 0-2).

    Parameters
    ----------
    G: :obj:`networkx.Graph`
        Networkx graph instance to be plotted. Every node needs an integer
        attribute ``n``.
    title: str
        Title of the final plot

    Returns
    -------
    :obj:`bokeh.models.Column`
        Column containing the graph plot and the layer MultiSelect.
    """

    plot = Plot(plot_width=600,
                plot_height=450,
                x_range=Range1d(-1.1, 1.1),
                y_range=Range1d(-1.1, 1.1))
    plot.title.text = title

    node_attrs = {}
    for node in G.nodes(data=True):
        node_color = Spectral4[node[1]['n']]
        node_attrs[node[0]] = node_color
    nx.set_node_attributes(G, node_attrs, "node_color")

    node_hover_tool = HoverTool(tooltips=[("Label", "@label"), ("n", "@n")])
    wheelZoom = WheelZoomTool()
    plot.add_tools(node_hover_tool, PanTool(), wheelZoom, ResetTool())
    plot.toolbar.active_scroll = wheelZoom

    graph_renderer = from_networkx(G,
                                   nx.spring_layout,
                                   k=0.3,
                                   iterations=200,
                                   scale=1,
                                   center=(0, 0))
    graph_renderer.node_renderer.glyph = Circle(size=15,
                                                fill_color="node_color")
    graph_renderer.edge_renderer.glyph = MultiLine(line_alpha=0.8,
                                                   line_width=1)

    plot.renderers.append(graph_renderer)

    # Recolours nodes on selection. node_color is rebuilt as a real array:
    # the previous version filled a plain object keyed by index, which is not
    # a valid data-source column.
    selectCallback = CustomJS(args=dict(graph_renderer=graph_renderer),
                              code="""
            const data = graph_renderer.node_renderer.data_source.data;
            // Same first three entries as Spectral4 on the Python side.
            const colors = ['#2b83ba', '#ABDDA4', '#fdae61'];
            const layers = new Set();
            for (const v of cb_obj.value) {
                if (v == 'fullGraph') {
                    layers.add(0);
                    layers.add(1);
                    layers.add(2);
                }
                if (v == 'n0') layers.add(0);
                if (v == 'n1') layers.add(1);
                if (v == 'n2') layers.add(2);
            }
            const node_color = Array.from(data['n'], (n) =>
                layers.has(Number(n)) ? colors[Number(n)] : 'transparent');
            graph_renderer.node_renderer.data_source.data =
                Object.assign({}, data, {node_color: node_color});
            """)

    multi_select = MultiSelect(title="Option:",
                               options=[("fullGraph", "Full Graph"),
                                        ("n0", "Seed Nodes"), ("n1", "N1"),
                                        ("n2", "N2")])
    multi_select.js_on_change('value', selectCallback)

    return column(plot, multi_select)
# Read-only text panels and a status note listing open work items.
BSURF_TEXT = PreText(text='', width=500)
EDOT_TEXT = PreText(text='', width=500)
status = PreText(
    text="""TODO:
1. Make P-Pdot plot more useful.
2. Add user comments.
3. Discuss classification tags.
4. Feedback.
5. Add histograms.
""",
    width=500,
)

CATALOGUE = Toggle(label="Toggle catalogue", width=300, button_type="success")
UPDATE = Button(label="Update", width=300)


def _tag_multiselect(title, category):
    """Return a 200 px tag picker over ``tags[category]``, preset to "Unclassified"."""
    return MultiSelect(title=title,
                       options=tags[category],
                       value=["Unclassified"],
                       height=200)


PROFILE = _tag_multiselect("Profile tags", "PROFILE")
POL = _tag_multiselect("Polarization tags", "POLARIZATION")
FREQ = _tag_multiselect("Frequency tags", "FREQUENCY")
TIME = _tag_multiselect("Time tags", "TIME")
OBS = MultiSelect(title="Observation tags",
                  options=tags["OBSERVATION"],
示例#26
0
    'y': y,
    'label': label
})  #create a dataframe for future use

source = ColumnDataSource(data=dict(x=x, y=y, label=label))

# Scatter plot coloured by each point's label.
plot_figure = figure(
    title='Multi-Select',
    height=450,
    width=600,
    tools="save,reset",
    toolbar_location="below",
)
plot_figure.scatter('x', 'y', color='label', source=source, size=10)

# Colour filter; both colours start selected.
multi_select = MultiSelect(
    title="Filter Plot by color:",
    value=["Red", "Orange"],
    options=[(colour, colour) for colour in ("Red", "Orange")],
)


def multiselect_click(attr, old, new):
    """Show only the points whose label is currently picked in ``multi_select``.

    Uses Bokeh's ``on_change`` callback signature ``(attr, old, new)``; the
    selection itself is read back from ``multi_select.value``.
    """
    chosen_labels = multi_select.value
    keep_rows = df['label'].isin(chosen_labels)
    filtered = df[keep_rows]
    # Replace the whole data dict so the browser-side source is refreshed.
    source.data = {
        'x': filtered.x,
        'y': filtered.y,
        'label': filtered.label,
    }


multi_select.on_change('value', multiselect_click)
示例#27
0
    def do_layout(self):
        """
        Build every widget, table and plot of the dashboard, wire their
        callbacks, arrange them in rows and columns and attach the result to
        the current Bokeh document.

        Side effects: sets ``self.source`` and ``self.layout``, calls
        ``self.generate_table_cumulative()`` and ``self.generate_table_new()``,
        adds ``self.layout`` as a root of ``curdoc()`` and sets the document
        title.
        :return: None
        """
        self.source = self.generate_source()
        tab_plot = self.generate_plot(self.source)
        # Country picker; its initial value is the currently active country list.
        multi_select = MultiSelect(title="Option (Multiselect Ctrl+Click):",
                                   value=self.active_country_list,
                                   options=countries,
                                   height=500)
        multi_select.on_change('value', self.update_data)
        tab_plot.on_change('active', self.update_tab)
        # Index 0 = "Total Cases", index 1 = "Cases per Million".
        radio_button_group_per_capita = RadioButtonGroup(
            labels=["Total Cases", "Cases per Million"],
            active=0 if not self.active_per_capita else 1)
        radio_button_group_per_capita.on_click(self.update_capita)
        # Labels are ordered [log, linear]; the active index is taken from the
        # enum value, which assumes Scale.log.value == 0 and
        # Scale.linear.value == 1 -- NOTE(review): confirm.
        radio_button_group_scale = RadioButtonGroup(
            labels=[Scale.log.name.title(),
                    Scale.linear.name.title()],
            active=self.active_y_axis_type.value)
        radio_button_group_scale.on_click(self.update_scale_button)
        # Dataset selector: confirmed / deaths / recovered.
        radio_button_group_df = RadioButtonGroup(labels=[
            Prefix.confirmed.name.title(),
            Prefix.deaths.name.title(),
            Prefix.recovered.name.title(),
        ],
                                                 active=int(
                                                     self.active_case_type))
        radio_button_group_df.on_click(self.update_data_frame)
        # Module-level loader (not a method) re-reads the data frames.
        refresh_button = Button(label="Refresh Data",
                                button_type="default",
                                width=150)
        refresh_button.on_click(load_data_frames)
        export_button = Button(label="Export Url",
                               button_type="default",
                               width=150)
        export_button.on_click(self.export_url)
        slider = Slider(start=1,
                        end=30,
                        value=self.active_window_size,
                        step=1,
                        title="Window Size for rolling average")
        slider.on_change('value', self.update_window_size)
        radio_button_average = RadioButtonGroup(
            labels=[Average.mean.name.title(),
                    Average.median.name.title()],
            active=self.active_average)
        radio_button_average.on_click(self.update_average_button)
        # Flags for which curves are drawn; the indices of the truthy flags
        # become the checkbox group's initially active buttons.
        plot_variables = [
            self.active_plot_raw, self.active_plot_average,
            self.active_plot_trend
        ]
        plots_button_group = CheckboxButtonGroup(
            labels=["Raw", "Averaged", "Trend"],
            active=[i for i, x in enumerate(plot_variables) if x])
        plots_button_group.on_click(self.update_shown_plots)

        world_map = self.create_world_map()
        # Shareable link built from self.url.
        link_div = Div(
            name="URL",
            text=
            fr'Link <a target="_blank" href="{self.url}">Link to this Plot</a>.',
            width=300,
            height=10,
            align='center')
        footer = Div(
            text=
            """Covid-19 Dashboard created by Andreas Weichslgartner in April 2020 with python, bokeh, pandas, 
            numpy, pyproj, and colorcet. Source Code can be found at 
            <a href="https://github.com/weichslgartner/covid_dashboard/">Github</a>.""",
            width=1600,
            height=10,
            align='center')
        # NOTE(review): generate_table_cumulative() is called just before the
        # *daily* table (which reads self.top_new_source) is built, and
        # generate_table_new() just before the *cumulative* table (which reads
        # self.top_total_source). The pairing looks swapped -- confirm which
        # source attribute each helper populates. Both run before the layout
        # is assembled, so the statement order here matters.
        self.generate_table_cumulative()
        columns = [
            TableColumn(field="name", title="Country"),
            TableColumn(field="number_rolling",
                        title="daily avg",
                        formatter=NumberFormatter(format="0.")),
            TableColumn(field="number_daily",
                        title="daily raw",
                        formatter=NumberFormatter(format="0."))
        ]
        top_top_14_new_header = Div(text="Highest confirmed (daily)",
                                    align='center')
        top_top_14_new = DataTable(source=self.top_new_source,
                                   name="Highest confirmed(daily)",
                                   columns=columns,
                                   width=300,
                                   height=380)
        self.generate_table_new()
        # `columns` is rebound here for the cumulative table; the daily table
        # above already holds its own list.
        columns = [
            TableColumn(field="name", title="Country"),
            TableColumn(field="number",
                        title="confirmed(cumulative)",
                        formatter=NumberFormatter(format="0."))
        ]

        top_top_14_cum_header = Div(text="Highest confirmed (cumulative)",
                                    align='center')
        top_top_14_cum = DataTable(source=self.top_total_source,
                                   name="Highest confirmed(cumulative)",
                                   columns=columns,
                                   width=300,
                                   height=380)
        # Three columns: plots + map | top-N tables | controls; footer below.
        self.layout = layout([
            row(
                column(tab_plot, world_map),
                column(top_top_14_new_header, top_top_14_new,
                       top_top_14_cum_header, top_top_14_cum),
                column(link_div, row(refresh_button, export_button),
                       radio_button_group_df, radio_button_group_per_capita,
                       plots_button_group, radio_button_group_scale, slider,
                       radio_button_average, multi_select),
            ),
            row(footer)
        ])

        curdoc().add_root(self.layout)
        curdoc().title = "Bokeh Covid-19 Dashboard"
# Substantiated-allegation table: third sheet (index 2) of the workbook,
# indexed by site, with missing cells set to False. ``sheet_name`` is the
# pandas keyword; the old ``sheetname`` spelling was removed in pandas 1.0.
data_substantiated = pd.read_excel('NRC_allegation_stats.xlsx',
                                   sheet_name=2).set_index("Site").fillna(False)
# NOTE(review): ``data.fillna('')`` / ``data_substantiated.fillna('')`` used to
# follow here. fillna is not in-place by default, so their results were thrown
# away and they had no effect; they were dropped as dead code. If ``data``
# really needs its NaNs blanked, assign the result explicitly -- but note that
# putting '' into numeric columns turns them into object dtype.
fruits = list(dataz['Site'])
years = [str(i) for i in data.columns.values]
# Trim the palette to one colour per year column.
palette = list(itertools.islice(palette, len(data.columns.values)))

years_widget = RangeSlider(start=2013,
                           end=2017,
                           value=(2015, 2017),
                           step=1,
                           title="Year[s]")

# Site names, one per line. ``with`` guarantees the file handle is closed.
# Empty lines (e.g. produced by a trailing newline) are skipped so the widget
# never offers a blank option.
with open(join(dirname(__file__), 'sites.txt')) as sites_file:
    site_options = [line for line in sites_file.read().splitlines() if line]
site_widget = MultiSelect(title="Site[s]",
                          value=["ARKANSAS 1 & 2"],
                          options=site_options)
# Bokeh invokes handlers as (attr, old, new); the lambda adapts that to an
# argument-less ``update()`` call. A lambda that merely names ``update``
# without calling it would make the widget inert.
# NOTE(review): assumes update() takes no arguments -- confirm at its definition.
site_widget.on_change('value', lambda attr, old, new: update())

# Pair every site with every year:
# [("site A", "2015"), ("site A", "2016"), ..., ("site B", "2015"), ...]
x = [(fruit, year) for fruit in fruits for year in years]
source = ColumnDataSource(data=dict(x=[], counts=[]))
source2 = ColumnDataSource(data=dict(x=[], counts=[]))

p = figure(x_range=data.index.values.tolist(),
           plot_height=350,
           title="Allegations Received from All Sources",
           toolbar_location=None,
           tools="")
p.vbar(x='x',
       top='counts',
示例#29
0
# One sample instance of each basic Bokeh input widget.
checkbox_button_group = CheckboxButtonGroup(
    labels=["Option 1", "Option 2", "Option 3"], active=[0, 1]
)
radio_button_group = RadioButtonGroup(
    labels=["Option 1", "Option 2", "Option 3"], active=0
)

checkbox_button_group_vertical = CheckboxButtonGroup(
    labels=["Option 1", "Option 2", "Option 3"],
    active=[0, 1],
    orientation="vertical",
)
radio_button_group_vertical = RadioButtonGroup(
    labels=["Option 1", "Option 2", "Option 3"],
    active=0,
    orientation="vertical",
)

text_input = TextInput(placeholder="Enter value ...")

completions = ["aaa", "aab", "aac", "baa", "caa"]
autocomplete_input = AutocompleteInput(
    placeholder="Enter value (auto-complete) ...",
    completions=completions,
)

text_area = TextAreaInput(placeholder="Enter text ...", cols=20, rows=10, value="uuu")

select = Select(options=["Option 1", "Option 2", "Option 3"])

# Sixteen numbered options: "Option 1" .. "Option 16".
multi_select = MultiSelect(options=[f"Option {n}" for n in range(1, 17)], size=6)

multi_choice = MultiChoice(options=[f"Option {n}" for n in range(1, 17)])

slider = Slider(value=10, start=0, end=100, step=0.5)

range_slider = RangeSlider(value=[10, 90], start=0, end=100, step=0.5)

date_slider = DateSlider(
    value=date(2016, 1, 1),
    start=date(2015, 1, 1),
    end=date(2017, 12, 31),
)

date_range_slider = DateRangeSlider(
    value=(date(2016, 1, 1), date(2016, 12, 31)),
    start=date(2015, 1, 1),
    end=date(2017, 12, 31),
)

spinner = Spinner(value=100)

color_picker = ColorPicker(color="red", title="Choose color:")
示例#30
0
    def __init__(self):
        """Create the widgets, figures and layout of the viewer.

        Builds the file-browser controls and the row-selection slider, wires
        their callbacks, sets up the full-sensor and selection figures,
        assembles ``self.layout`` and finally calls ``self.initialize_ui()``.
        """
        # Loaded-data state; all empty at construction.
        self.spe_file = None
        self.full_sensor_data = None
        self.selection_data = None

        # ---- widgets ----------------------------------------------------
        default_directory = os.path.join(os.getcwd(), 'data')
        self.directory_input = TextInput(placeholder='Directory',
                                         value=default_directory)
        self.show_files_button = Button(label='Show Files',
                                        button_type='primary')
        self.file_view = MultiSelect(size=5)
        self.open_file_button = Button(label='Open File',
                                       button_type='warning')
        self.update_selection_button = Button(label='Update Selection',
                                              button_type='success')
        self.selection_range = RangeSlider(start=0, end=1, value=(0, 1),
                                           step=1, title='Selected Rows')

        # ---- callbacks --------------------------------------------------
        click_handlers = (
            (self.show_files_button, self.update_file_browser),
            (self.open_file_button, self.open_file_callback),
            (self.update_selection_button, self.update_selection),
        )
        for clicked, handler in click_handlers:
            clicked.on_click(handler)
        self.selection_range.on_change('value', self.selection_range_callback)

        # ---- full-sensor figure ----------------------------------------
        full_image = figure(x_range=(0, 1),
                            y_range=(0, 1023),
                            tools='pan,box_zoom,wheel_zoom,reset',
                            plot_width=512,
                            plot_height=512)
        self.full_sensor_image = full_image
        self.full_sensor_image_label = Label(x=0.1, y=0.1,
                                             text='Source Data',
                                             text_font_size='36pt',
                                             text_color='#eeeeee')
        full_image.add_tools(BoxSelectTool(dimensions='height'))
        full_image.grid.grid_line_color = None
        # Blank out ticks and tick labels on both axes.
        for axes in (full_image.xaxis, full_image.yaxis):
            axes.major_tick_line_color = None
            axes.minor_tick_line_color = None
            axes.major_label_text_font_size = '0pt'
        self.selection_lines_coords = ColumnDataSource(
            data=dict(x=[[0, 1], [0, 1]], y=[[0, 0], [1, 1]]))

        # ---- selection figure ------------------------------------------
        selection_image = figure(x_range=(0, 1),
                                 y_range=(0, 1),
                                 tools='wheel_zoom',
                                 plot_width=1024,
                                 plot_height=180)
        self.selection_image = selection_image
        self.selection_image_label = Label(x=0.1, y=0.2,
                                           text='Selection Region',
                                           text_font_size='36pt',
                                           text_color='#eeeeee')
        selection_image.grid.grid_line_color = None

        # ---- layout -----------------------------------------------------
        controls = [
            self.directory_input, self.show_files_button, self.file_view,
            self.open_file_button, self.selection_range,
            self.update_selection_button,
        ]
        self.layout = layout(
            children=[[widgetbox(*controls, width=500)],
                      [full_image, selection_image]],
            sizing_mode='fixed')

        # Apply defaults now that every widget exists.
        self.initialize_ui()
示例#31
0
class Trends:
    """Trends layout
    """
    def __init__(self, palette=Purples[3]):
        """Build the cumulative cases/deaths plots and the state selector.

        Args:
            palette: colour palette applied to both ``LinePlot`` objects.
        """
        def build(table, heading, y_label):
            # Render one by-state LinePlot with the shared axis/colour setup.
            plot = LinePlot(table)
            plot.render_figure()
            plot.title(heading)
            plot.axis_label('Date', y_label)
            plot.color_palette(palette)
            return plot

        self.cases = build(ARIMA_CASES_TABLE, "Cumulative Cases by State", 'Cases')
        LOG.debug('state cases')

        self.deaths = build(ARIMA_DEATHS_TABLE, "Cumulative Deaths by State", 'Deaths')
        LOG.debug('state deaths')

        self.multiselect = None
        self._add_multiselect()
        # Changing `value` after _add_multiselect attached its 'value'
        # handlers is expected to run them and draw these default states.
        self.multiselect.value = ['12', '34', '36']

        LOG.debug('render default states')

    def _add_multiselect(self):
        """Create the state selector and route its value changes to both plots.

        The selector offers ``self.cases.options`` and starts on state '01';
        it is stored on ``self.multiselect``.
        """
        picker = MultiSelect(title='States:', value=['01'],
                             options=self.cases.options,
                             max_width=170, min_height=500 - 47)
        self.multiselect = picker
        # Cases first, then deaths -- same registration order as always.
        for handler in (self._callback_cases, self._callback_deaths):
            picker.on_change('value', handler)

    def _callback_cases(self, _attr, _old, new):
        """Bokeh ``on_change`` handler: show exactly the selected states on the cases plot."""
        self._sync_state_visibility(self.cases, new)

    def _callback_deaths(self, _attr, _old, new):
        """Bokeh ``on_change`` handler: show exactly the selected states on the deaths plot."""
        self._sync_state_visibility(self.deaths, new)

    def _sync_state_visibility(self, plot, selected):
        """Make the renderers of exactly the states in *selected* visible on *plot*.

        Each state has five renderers (``actual``, ``predict``, ``lower``,
        ``upper``, ``area``), all keyed by the state id.

        - States offered by the multiselect that are visible but no longer
          selected are hidden.
        - Selected states that are not yet visible first get their source
          filled from ``plot.data.loc[state_id, :]`` and are then shown.
        - States that stay selected are left alone, so their data is not
          re-sent to the browser on every change.

        Args:
            plot: a LinePlot exposing ``actual``/``predict``/``lower``/
                ``upper``/``area``/``source`` dicts and a ``data`` frame.
            selected: iterable of state ids (the multiselect's new value).
        """
        wanted = set(selected)
        renderer_groups = (plot.actual, plot.predict, plot.lower,
                           plot.upper, plot.area)

        # Options are (id, label) pairs; hide deselected states.
        for state_id, _label in list(self.multiselect.options):
            if state_id not in wanted and plot.actual[state_id].visible:
                for group in renderer_groups:
                    group[state_id].visible = False

        # Load and reveal newly selected states.
        for state_id in selected:
            if not plot.actual[state_id].visible:
                state_rows = plot.data.loc[state_id, :]
                plot.source[state_id].data = ColumnDataSource.from_df(data=state_rows)
                for group in renderer_groups:
                    group[state_id].visible = True

    def layout(self):
        """Build trend layout.

        The cases and deaths figures are stacked in a one-column grid. The
        grid's width is 800 minus the selector's ``max_width``, and the state
        selector is placed to its right.

        Returns:
            Bokeh Layout -- layout with cases, deaths and state selection
        """
        graph_width = 800 - self.multiselect.max_width
        stacked = gridplot(
            [self.cases.plot, self.deaths.plot],
            ncols=1,
            plot_width=graph_width,
            plot_height=250,
            toolbar_location=None,
        )
        return row(stacked, self.multiselect)