def getTimeSeries_works(df_tws):
    """Build a line plot whose exponent is controlled by a slider.

    The plot starts as y = x on 200 samples spaced 0.005 apart from 0. Moving
    the "power" slider runs a CustomJS callback in the browser that rewrites
    every y as x ** value and re-emits the data source change.

    :param df_tws: not used by this function; kept for interface compatibility.
    :return: a ``column`` layout holding the plot above the slider.
    """
    output_file("js_on_change.html")

    # Both columns intentionally reference the same list: power 1 == identity.
    samples = [step * 0.005 for step in range(0, 200)]
    source = ColumnDataSource(data=dict(x=samples, y=samples))

    plot = Figure(plot_width=400, plot_height=400)
    plot.line('x', 'y', source=source, line_width=3, line_alpha=0.6)

    # Client-side update: y[i] = x[i] ** slider value.
    power_js = """
	    var data = source.data;
	    var f = cb_obj.value
	    var x = data['x']
	    var y = data['y']
	    for (var i = 0; i < x.length; i++) {
	        y[i] = Math.pow(x[i], f)
	    }
	    source.change.emit();
	"""
    power_callback = CustomJS(args=dict(source=source), code=power_js)

    slider = Slider(start=0.1, end=4, value=1, step=.1, title="power")
    slider.js_on_change('value', power_callback)

    return column(plot, slider)
Exemple #2
0
def mg_slider(source, timeList, time, bar_source, status):
    """Create a "Time" slider whose JavaScript callback is read from ``slider.js``.

    The slider ranges over 1..len(time) in integer steps. Its ``value`` change
    triggers a CustomJS callback whose code is the contents of
    ``<dir_path>/slider.js``. ``dir_path`` is a module-level name defined
    elsewhere (presumably this module's directory -- TODO confirm).

    :param source: data source made available to the JS as ``source``
    :param timeList: passed to the JS as ``timeList``
    :param time: sequence whose length sets the slider's upper bound
    :param bar_source: passed to the JS as ``bar_source``
    :param status: passed to the JS as ``status``
    :return: the configured ``Slider``
    """
    import os

    slider = Slider(start=1, end=len(time), value=1, step=1, title="Time")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    js_path = os.path.join(dir_path, 'slider.js')
    with open(js_path, 'r', encoding='utf-8') as slider_file:
        slider_code = slider_file.read()

    callback = CustomJS(args=dict(source=source,
                                  slider=slider,
                                  timeList=timeList,
                                  bar_source=bar_source,
                                  status=status),
                        code=slider_code)
    slider.js_on_change('value', callback)
    return slider
Exemple #3
0
def init():
    """Prepare the notebook for the RISE slideshow and display its control panel.

    Enables Bokeh notebook output and injects ``_init_js`` (defined elsewhere in
    this module). It then shows a welcome message and three controls whose
    CustomJS callbacks call ``primes_clear()``, ``primes_incremental(active)``
    and ``primes_repeats(value)``. Those JS functions are presumably defined by
    ``_init_js`` -- TODO confirm.
    """
    output_notebook()
    display(Javascript(_init_js))

    but = '<img src="resources/show.png" width="34" height="25" style="display: inline" alt="Slideshow button" title="Enter/Exit RISE Slideshow">'
    # Each heading must be closed with its own tag (h2 -> </h2>, h3 -> </h3>),
    # otherwise the browser nests the following content inside the heading.
    txt = Div(text='<h2>You can now start the slideshow!</h2>' +
                   f'<h3 style="margin: 0.5em 0;">Just click the RISE slideshow button above - the one that looks like: {but}<br/>' +
                   '(or you can press alt+R on your keyboard instead if you prefer).</h3>')

    clearbutton = Button(label="Clear")
    clearbutton.js_on_click(CustomJS(code='primes_clear();'))
    cleartext = Paragraph(text='Clear all plots and outputs (e.g. before restarting slideshow).')

    increm = Toggle(label='Incremental', active=True)
    increm.js_on_click(CustomJS(code='primes_incremental(cb_obj.active)'))
    incremtext = Paragraph(text='Update timing plots incrementally (disable for static slide show).')

    repeats = Slider(start=1, end=10, value=3)
    repeats.js_on_change('value', CustomJS(code='primes_repeats(cb_obj.value)'))
    repeatstext = Paragraph(text='Repeats for timing measurements (higher is more accurate, but slower).')

    # One row per control: widget on the left, explanation on the right.
    controls = layout([[clearbutton, cleartext],
                       [increm, incremtext],
                       [repeats, repeatstext]])
    show(column(txt, controls, sizing_mode='stretch_width'))
Exemple #4
0
def make_interactive_plot(qr: np.ndarray, pvals, disp_data_full: list,
                          disp_inter_full: list, labels: dict,
                          filename: str = "disp_app.html") -> None:
    """Create the interactive Bokeh plot as a standalone HTML web page.

    For every (pressure file ``ip``, eigenvalue ``ie``) pair one Circle glyph
    (raw points) and one Line glyph (interpolated curve) are added. Both read
    column ``"e<ip>_<ie>"``. A glyph is drawn with alpha 1 only when its file
    is the one selected by the slider and its eigenvalue is ticked in the
    checkbox group; every other glyph gets alpha 0.

    :param qr: x-values for the raw points
    :type qr: np.ndarray
    :param pvals: slider values (one pressure label per file); each entry is
        shown as ``"Pressure = <value> bar"``
    :type pvals: list
    :param disp_data_full: list of arrays containing q vectors and eigenvalues
        for the points; eigenvalue ``ie`` is read from column ``ie + 3``
    :type disp_data_full: list
    :param disp_inter_full: list of arrays with interpolated x-values in
        column 0 followed by one column per eigenvalue
    :type disp_inter_full: list
    :param labels: maps x positions of special points to their axis labels
    :type labels: dict
    :param filename: name of the HTML file passed to ``output_file``
    :type filename: str
    :return: None
    """
    # Plotting (only what is actually used below)
    from bokeh.plotting import figure, output_file
    from bokeh.io import show
    from bokeh.models import ColumnDataSource, CustomJS, Circle, Line, Div, HoverTool
    from bokeh.models.widgets import CheckboxGroup, Slider
    from bokeh.layouts import column, row

    nfiles = len(disp_data_full)
    neigvals = disp_inter_full[0].shape[1] - 1  # first column is x-values

    # Output file name (web page)
    output_file(filename)

    # Create the figure
    p = figure(title="Dispersion curves for FCC",
               y_range=(0, 2.25),
               x_range=(0, qr[-1] + 0.001),
               x_axis_label="q",
               y_axis_label="ω (THz)",
               plot_width=800)

    def _col(ip, ie):
        """Column name shared by the raw and interpolated sources."""
        return f"e{ip}_{ie}"

    # Make dictionaries for Bokeh ColumnDataSource
    draw = {"xraw": qr}
    dint = {"xint": disp_inter_full[0][:, 0]}
    for ip in range(nfiles):
        for ie in range(neigvals):
            col = _col(ip, ie)
            draw[col] = disp_data_full[ip][:, ie + 3]
            dint[col] = disp_inter_full[ip][:, ie + 1]

    srcraw = ColumnDataSource(draw)
    srcint = ColumnDataSource(dint)

    # Glyphs for points and lines (interpolation). Glyph index
    # ip * neigvals + ie is the layout the JS callbacks rely on.
    raw_glyphs = []
    inter_glyphs = []
    for ip in range(nfiles):
        # Only the first pressure series is visible initially (slider value 1).
        lalpha = 1 if ip == 0 else 0
        for ie in range(neigvals):
            col = _col(ip, ie)
            raw = Circle(x="xraw", y=col, line_alpha=lalpha, fill_alpha=lalpha)
            inter = Line(x="xint", y=col, line_alpha=lalpha)
            raw_glyphs.append(raw)
            inter_glyphs.append(inter)
            p.add_glyph(srcraw, raw)
            p.add_glyph(srcint, inter)

    # Styling: ticks only at the special points, labelled from `labels`
    p.xaxis.ticker = list(labels.keys())
    p.xaxis.major_label_overrides = labels

    # Widgets
    eigen_labels = [f"Eigen {i + 1}" for i in range(neigvals)]
    eigen_chkbx_grp = CheckboxGroup(labels=eigen_labels,
                                    active=list(range(neigvals)))
    slider = Slider(start=1,
                    end=nfiles,
                    value=1,
                    step=1,
                    title="Pressure series")
    # f-string so numeric pressure values work as well as strings.
    div = Div(text=f"Pressure = {pvals[0]} bar", name="pval")

    # Slider callback: show only the selected series, honouring the checkboxes.
    slider_js = """
        const sel = cb_obj.value;
        const eigs = chkbx.active;
        for (let i = 0; i < glp.length; i++) {
            const inSeries = i >= (sel - 1) * neig && i < sel * neig;
            const alpha = (inSeries && eigs.includes(i % neig)) ? 1 : 0;
            glp[i].fill_alpha = alpha;
            glp[i].line_alpha = alpha;
            glq[i].line_alpha = alpha;
        }
        lbl.text = "Pressure = " + vals[sel - 1] + " bar";
    """
    slider.js_on_change(
        "value",
        CustomJS(args=dict(glp=raw_glyphs,
                           glq=inter_glyphs,
                           neig=neigvals,
                           lbl=div,
                           vals=pvals,
                           chkbx=eigen_chkbx_grp),
                 code=slider_js))

    # Checkbox callback: toggle eigenvalues within the current series only.
    checkbox_js = """
        const actv = cb_obj.active;
        const start = (sldr.value - 1) * neig;
        for (let i = start; i < start + neig; i++) {
            const alpha = actv.includes(i % neig) ? 1 : 0;
            glp[i].fill_alpha = alpha;
            glp[i].line_alpha = alpha;
            glq[i].line_alpha = alpha;
        }
    """
    # 'active' is the property a checkbox click changes.
    eigen_chkbx_grp.js_on_change(
        "active",
        CustomJS(args=dict(glp=raw_glyphs,
                           glq=inter_glyphs,
                           neig=neigvals,
                           sldr=slider),
                 code=checkbox_js))

    # Hover shows the frequency under the cursor
    hover = HoverTool(tooltips=[("ω (THz)", "$y")])
    p.add_tools(hover)

    # Make the layout and the web page
    layout = column(eigen_chkbx_grp, p, row(slider, div))
    show(layout)
Exemple #5
0
def similarityGraph(si, sy, ey, db, abstractsim):
    """Build the interactive tSNE similarity page for a PubMed search.

    Calls ``SimilarityCalc(si, sy, ey, db, abstractsim)`` (defined elsewhere),
    which returns a dict ``d``. The keys read here are ``authors``,
    ``pmccites``, ``dates``, ``ids``, ``Y``, ``titles``, ``journals``,
    ``topwords``, ``kcenters``, ``minc``, ``maxc``, ``rids``, ``maxids`` and
    ``sdfc_len``.

    :param si: search term; shown in the plot title and (HTML-escaped) in the
        page heading
    :param sy: forwarded to ``SimilarityCalc``
    :param ey: forwarded to ``SimilarityCalc``
    :param db: forwarded to ``SimilarityCalc``
    :param abstractsim: forwarded to ``SimilarityCalc``
    :return: a Bokeh ``layout`` containing, in rows: the heading, the search
        box, the reset button + label slider, the scatter plot, the export
        button and the selection table
    """
    import html

    # d holds everything: ids/dates/authors/pmccites per article, Y = 2-D
    # coordinates from dimensionality reduction, topwords/kcenters = cluster
    # labels and their centres for each cluster count.
    d = SimilarityCalc(si, sy, ey, db, abstractsim)

    # authors come as lists of names; join each into one display string
    authors_str = [", ".join(auths) for auths in d['authors']]

    # scaled point size based on citation count
    minw = 8
    maxw = 30
    ptsizes = getScaledSizes(list(map(int, d['pmccites'])), minw, maxw)

    # colors based on publication date
    colors = getScaledColors(d['dates'])

    alphas = [1] * len(d['ids'])
    # `colorsperm` keeps the original colors so highlighting can be undone.
    source = ColumnDataSource(data=dict(x=d['Y'][:, 0],
                                        y=d['Y'][:, 1],
                                        PMID=d['ids'],
                                        titles=d['titles'],
                                        authors=authors_str,
                                        journals=d['journals'],
                                        dates=d['dates'],
                                        alphas=alphas,
                                        pmccites=d['pmccites'],
                                        ptsizes=ptsizes,
                                        colors=colors,
                                        colorsperm=colors))

    ######## publication view table filled from the selection on the tsne plot
    pubview_data = dict(titles=["Title"],
                        dates=["Date"],
                        journals=["Journal"],
                        authors=["Author"],
                        pmccites=["PMC Citations"],
                        PMID=["pmids"])

    pubview_source = ColumnDataSource(pubview_data)

    pubview_columns = [
        TableColumn(field="titles", title="Article Title", width=400),
        TableColumn(field="authors", title="Authors", width=50),
        TableColumn(field="journals", title="Journal", width=50),
        TableColumn(field="dates", title="Date", width=80),
        TableColumn(field="pmccites", title="PMC Citations", width=80),
        TableColumn(field="PMID", title="PMIDS", width=0),
    ]

    pubview_table = DataTable(source=pubview_source,
                              columns=pubview_columns,
                              width=930,
                              height=400)

    # Copy the selected points into the table's source (loop vars declared
    # with `var` so they do not leak as implicit globals).
    source.callback = CustomJS(args=dict(pubview_table=pubview_table),
                               code="""
        var selecteddata = cb_obj.selected["1d"].indices
        var count = 0
        var s1 = cb_obj.get('data');
        var d2 = pubview_table.get('source').get('data');
        d2.index = []
        d2.authors = []
        d2.titles = []
        d2.journals = []
        d2.dates = []
        d2.pmccites = []
        d2.PMID = []
        for (var k = 0; k < selecteddata.length; k++) {
            var tind = selecteddata[k]
            d2.index.push(count)
            d2.authors.push(s1.authors[tind])
            d2.titles.push(s1.titles[tind])
            d2.journals.push(s1.journals[tind])
            d2.dates.push(s1.dates[tind])
            d2.pmccites.push(parseInt(s1.pmccites[tind]))
            d2.PMID.push(s1.PMID[tind])
            count += 1
        }
        pubview_table.trigger('change');
        """)

    # Clicking a table row opens that article; nothing to open when the
    # selection was cleared (previously opened ".../pubmed/undefined").
    pubview_source.callback = CustomJS(code="""
        var selecteddata = cb_obj.selected["1d"].indices
        if (selecteddata.length === 0) { return; }
        var s1 = cb_obj.get('data');
        var url = "https://www.ncbi.nlm.nih.gov/pubmed/"+s1.PMID[selecteddata[0]]
        window.open(url,'_blank');

    """)
    ####### END TABLE DISPLAY CODE #####

    ##### max-width IS IMPORTANT FOR PROPER WRAPPING OF TEXT
    hover = HoverTool(tooltips="""
            <div>
                <div style="max-width: 400px;">
                    <span style="font-size: 12px; font-weight: bold;">@titles</span>
                </div>
                <div style="max-width: 400px;">
                    <span style="font-size: 12px; color: #966;">@authors</span>
                </div>
                <div style="max-width: 400px;">
                    <span style="font-size: 12px; font-style: italic;">@journals, @dates</span>
                </div>
                <div style="max-width: 400px;">
                    <span style="font-size: 10px;">PMID</span>
                    <span style="font-size: 10px; color: #696;">@PMID</span>
                </div>
                <div style="max-width: 400px;">
                    <span style="font-size: 10px;">PMC Citations</span>
                    <span style="font-size: 10px; color: #696;">@pmccites</span>
                </div>
            </div>
            """)

    # Restore original colors and full opacity for every point.
    resetCallback = CustomJS(args=dict(source=source),
                             code="""
        var data = source.get('data')
        var titles = data['titles']
        for (var i = 0; i < titles.length; i++) {
            data.colors[i] = data.colorsperm[i]
            data.alphas[i] = 1
        }
        source.trigger('change')
    """)

    # Runs on every text change: dim everything, then highlight (orange) the
    # points for which every space-separated word appears in the title,
    # authors or journal (case-insensitive).
    textCallback = CustomJS(args=dict(source=source),
                            code="""
        var data = source.get('data')
        var value = cb_obj.get('value')
        var words = value.split(" ")
        for (var i = 0; i < data.titles.length; i++) {
            data.alphas[i] = 0.3
            data.colors[i] = data.colorsperm[i]
        }
        for (var i = 0; i < data.titles.length; i++) {
            for (var j = 0; j < words.length; j++) {
                var w = words[j].toLowerCase()
                var hit = data.titles[i].toLowerCase().indexOf(w) !== -1 ||
                          data.authors[i].toLowerCase().indexOf(w) !== -1 ||
                          data.journals[i].toLowerCase().indexOf(w) !== -1
                if (!hit) {
                    break
                }
                if (j == words.length - 1) {
                    data.colors[i] = 'orange'
                    data.alphas[i] = 1
                }
            }
        }
        source.trigger('change')
    """)

    # Open PubMed with all PMIDs currently in the table (comma-joined).
    publistcallback = CustomJS(args=dict(pubview_table=pubview_table),
                               code="""
        var pmids = pubview_table.get('source').get('data').PMID;
        var pmidlist = pmids.join()
        var url = "https://www.ncbi.nlm.nih.gov/pubmed/"+pmidlist
        window.open(url,'_blank');
    """)

    TOOLS = 'pan,lasso_select,wheel_zoom,tap,reset'
    p = figure(plot_width=900,
               plot_height=600,
               title="'" + si + "' tSNE similarity",
               tools=[TOOLS, hover],
               active_scroll='wheel_zoom',
               active_drag="lasso_select")

    p.circle('x',
             'y',
             fill_color='colors',
             fill_alpha='alphas',
             size='ptsizes',
             line_color="#000000",
             line_alpha=0.2,
             source=source)

    # word labels for the cluster centres, one source per cluster count
    wordsources = []
    for idx, words in enumerate(d['topwords']):
        centers = d['kcenters'][idx]
        wordsources.append(
            ColumnDataSource(dict(x=centers[:, 0], y=centers[:, 1], text=words)))

    wordglyph = Text(x="x",
                     y="y",
                     text="text",
                     text_color="#000000",
                     text_font_style="bold",
                     text_font_size="14pt")
    # index of the label set shown initially; the slider starts at minc + 5
    initialclust = 5
    wordholdsource = ColumnDataSource(
        dict(x=d['kcenters'][initialclust][:, 0],
             y=d['kcenters'][initialclust][:, 1],
             text=d['topwords'][initialclust]))
    p.add_glyph(wordholdsource, wordglyph)

    # Sources are exposed to JS as wordsource<minc>, wordsource<minc+1>, ...
    args = {"wordholdsource": wordholdsource}
    for idx, wsrc in enumerate(wordsources):
        args["wordsource" + str(idx + d['minc'])] = wsrc

    # eval hack: the callback args cannot hold a list of sources with the
    # targeted Bokeh backend. The evaluated name is built only from the
    # slider's numeric value, never from user text.
    slidercallback = CustomJS(args=args,
                              code="""
        var f = cb_obj.value
        var ndata = eval('wordsource' + f.toString()).data;
        wordholdsource.data.x = ndata.x
        wordholdsource.data.y = ndata.y
        wordholdsource.data.text = ndata.text
        wordholdsource.trigger('change');
    """)

    wslider = Slider(start=d['minc'],
                     end=d['maxc'],
                     value=d['minc'] + initialclust,
                     step=1,
                     title="# of labels")
    wslider.js_on_change('value', slidercallback)

    # formatting plot: hide all ticks and tick labels
    p.xaxis.axis_label = "Hover to view publication info, Click to open Pubmed link"
    p.xaxis.major_tick_line_color = None
    p.xaxis.minor_tick_line_color = None
    p.yaxis.major_tick_line_color = None
    p.yaxis.minor_tick_line_color = None
    p.xaxis.major_label_text_font_size = '0pt'
    p.yaxis.major_label_text_font_size = '0pt'
    xs, ys = d['Y'][:, 0], d['Y'][:, 1]
    p.x_range = Range1d(np.amin(xs) * 1.1, np.amax(xs) * 1.1)
    p.y_range = Range1d(np.amin(ys) * 1.1, np.amax(ys) * 1.1)

    # tap tool callback: open the clicked article on PubMed
    url = "https://www.ncbi.nlm.nih.gov/pubmed/@PMID"
    taptool = p.select(type=TapTool)
    taptool.callback = OpenURL(url=url)

    # word input callback
    word_input = TextInput(title="Search for term(s) in graph",
                           placeholder="Enter term to highlight",
                           callback=textCallback)
    reset = Button(label="Clear Highlighting",
                   callback=resetCallback,
                   width=150)

    spdiv = Div(text="&nbsp;", width=100, height=20)

    # extra message saying the data was cut off (vague search terms)
    cutoff_message = ""
    if len(d['rids']) == d['maxids']:
        cutoff_message = "<br>(Search was truncated to " + str(
            d['maxids']) + " newest articles due to memory constraints)"

    # `si` is the user's search term: escape it before embedding in HTML.
    tit1 = Div(text="<h1>" + html.escape(si) +
               " similarity plot</h1><br><h5>Displaying the " +
               str(d['sdfc_len']) + "/" + str(len(d['rids'])) +
               " articles that have data on pubmed" + cutoff_message + "</h5>",
               width=930)

    pubmed_list_button = Button(label="Export selected publications to pubmed",
                                callback=publistcallback)
    lt = layout([[tit1], [word_input], [reset, spdiv, wslider], [p],
                 [pubmed_list_button], [pubview_table]])

    return lt
def plot_cross_section_bokeh(filename, map_data_all_slices, map_depth_all_slices, \
                             color_range_all_slices, cross_data, boundary_data, \
                             style_parameter):
    '''
    Plot shear velocity maps and cross-sections using bokeh

    Input:
        filename is the filename of the resulting html file
        map_data_all_slices contains the velocity model parameters saved for map view plots
        map_depth_all_slices is a list of depths
        color_range_all_slices is a list of color ranges
        profile_data_all is a list of velocity profiles
        cross_lat_data_all is a list of cross-sections along latitude
        lat_value_all is a list of corresponding latitudes for these cross-sections
        cross_lon_data_all is a list of cross-sections along longitude
        lon_value_all is a list of corresponding longitudes for these cross-sections
        boundary_data is a list of boundaries
        style_parameter contains parameters to customize the plots

    Output:
        None
    
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    #
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    map_view_ndepth = style_parameter['map_view_ndepth']
    palette_r = palette[::-1]
    ncolor = len(palette_r)
    colorbar_top = [0.1 for i in range(ncolor)]
    colorbar_bottom = [0 for i in range(ncolor)]
    map_data_all_slices_depth = []
    for idepth in range(map_view_ndepth):
        color_min = color_range_all_slices[idepth][0]
        color_max = color_range_all_slices[idepth][1]
        color_step = (color_max - color_min) * 1. / ncolor
        colorbar_left = np.linspace(color_min, color_max - color_step, ncolor)
        colorbar_right = np.linspace(color_min + color_step, color_max, ncolor)
        colorbar_data_all_left.append(colorbar_left)
        colorbar_data_all_right.append(colorbar_right)
        map_depth = map_depth_all_slices[idepth]
        map_data_all_slices_depth.append(
            'Depth: {0:8.0f} km'.format(map_depth))
    # data for the colorbar
    colorbar_data_one_slice = {}
    colorbar_data_one_slice['colorbar_left'] = colorbar_data_all_left[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice['colorbar_right'] = colorbar_data_all_right[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(colorbar_top=colorbar_top,colorbar_bottom=colorbar_bottom,\
                                                               colorbar_left=colorbar_data_one_slice['colorbar_left'],\
                                                               colorbar_right=colorbar_data_one_slice['colorbar_right'],\
                                                               palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(colorbar_data_all_left=colorbar_data_all_left,\
                                                                colorbar_data_all_right=colorbar_data_all_right))
    #
    map_view_label_lon = style_parameter['map_view_depth_label_lon']
    map_view_label_lat = style_parameter['map_view_depth_label_lat']
    map_data_one_slice_depth = map_data_all_slices_depth[
        style_parameter['map_view_default_index']]
    map_data_one_slice_depth_bokeh = ColumnDataSource(
        data=dict(lat=[map_view_label_lat],
                  lon=[map_view_label_lon],
                  map_depth=[map_data_one_slice_depth]))

    #
    map_view_default_index = style_parameter['map_view_default_index']
    #map_data_one_slice = map_data_all_slices[map_view_default_index]
    map_color_all_slices = []
    for i in range(len(map_data_all_slices)):
        vmin, vmax = color_range_all_slices[i]
        map_color = val_to_rgb(map_data_all_slices[i], palette_r, vmin, vmax)
        map_color_2d = map_color.view('uint32').reshape(map_color.shape[:2])
        map_color_all_slices.append(map_color_2d)
    map_color_one_slice = map_color_all_slices[map_view_default_index]
    #
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(x=[style_parameter['map_view_image_lon_min']],\
                   y=[style_parameter['map_view_image_lat_min']],dw=[style_parameter['nlon']],\
                   dh=[style_parameter['nlat']],map_data_one_slice=[map_color_one_slice]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(map_data_all_slices=map_color_all_slices,\
                                                           map_data_all_slices_depth=map_data_all_slices_depth))
    #

    plot_depth = np.shape(cross_data)[0] * style_parameter['cross_ddepth']
    plot_lon = great_arc_distance(style_parameter['cross_default_lat0'], style_parameter['cross_default_lon0'],\
                                  style_parameter['cross_default_lat1'], style_parameter['cross_default_lon1'])

    vs_min = style_parameter['cross_view_vs_min']
    vs_max = style_parameter['cross_view_vs_max']
    cross_color = val_to_rgb(cross_data, palette_r, vs_min, vs_max)
    cross_color_2d = cross_color.view('uint32').reshape(cross_color.shape[:2])
    cross_data_bokeh = ColumnDataSource(data=dict(x=[0],\
                   y=[plot_depth],dw=[plot_lon],\
                   dh=[plot_depth],cross_data=[cross_color_2d]))

    map_line_bokeh = ColumnDataSource(data=dict(lat=[style_parameter['cross_default_lat0'], style_parameter['cross_default_lat1']],\
                                                    lon=[style_parameter['cross_default_lon0'], style_parameter['cross_default_lon1']]))
    #
    ncolor_cross = len(my_palette)
    colorbar_top_cross = [0.1 for i in range(ncolor_cross)]
    colorbar_bottom_cross = [0 for i in range(ncolor_cross)]
    color_min_cross = style_parameter['cross_view_vs_min']
    color_max_cross = style_parameter['cross_view_vs_max']
    color_step_cross = (color_max_cross - color_min_cross) * 1. / ncolor_cross
    colorbar_left_cross = np.linspace(color_min_cross,
                                      color_max_cross - color_step_cross,
                                      ncolor_cross)
    colorbar_right_cross = np.linspace(color_min_cross + color_step_cross,
                                       color_max_cross, ncolor_cross)
    # ==============================
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'], plot_height=style_parameter['map_view_plot_height'], \
                      tools=style_parameter['map_view_tools'], title=style_parameter['map_view_title'], \
                      y_range=[style_parameter['map_view_figure_lat_min'], style_parameter['map_view_figure_lat_max']],\
                      x_range=[style_parameter['map_view_figure_lon_min'], style_parameter['map_view_figure_lon_max']])
    #
    map_view.image_rgba('map_data_one_slice',x='x',\
                   y='y',dw='dw',dh='dh',\
                   source=map_data_one_slice_bokeh, level='image')

    depth_slider_callback = CustomJS(args=dict(map_data_one_slice_bokeh=map_data_one_slice_bokeh,\
                                               map_data_all_slices_bokeh=map_data_all_slices_bokeh,\
                                               colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh,\
                                               colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,\
                                               map_data_one_slice_depth_bokeh=map_data_one_slice_depth_bokeh), code="""

        var d_index = Math.round(cb_obj.value)
        
        var map_data_all_slices = map_data_all_slices_bokeh.data
        
        map_data_one_slice_bokeh.data['map_data_one_slice'] = [map_data_all_slices['map_data_all_slices'][d_index]]
        map_data_one_slice_bokeh.change.emit()
        
        var color_data_all_slices = colorbar_data_all_slices_bokeh.data
        colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][d_index]
        colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][d_index]
        colorbar_data_one_slice_bokeh.change.emit()
        
        map_data_one_slice_depth_bokeh.data['map_depth'] = [map_data_all_slices['map_data_all_slices_depth'][d_index]]
        map_data_one_slice_depth_bokeh.change.emit()
        
    """)
    depth_slider = Slider(start=0, end=style_parameter['map_view_ndepth']-1, \
                          value=map_view_default_index, step=1, \
                          width=style_parameter['map_view_plot_width'],\
                          title=style_parameter['depth_slider_title'], height=50)
    depth_slider.js_on_change('value', depth_slider_callback)
    depth_slider_callback.args["depth_index"] = depth_slider
    # ------------------------------
    # add boundaries to map view
    # country boundaries
    map_view.multi_line(boundary_data['country']['longitude'],\
                        boundary_data['country']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # marine boundaries
    map_view.multi_line(boundary_data['marine']['longitude'],\
                        boundary_data['marine']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # shoreline boundaries
    map_view.multi_line(boundary_data['shoreline']['longitude'],\
                        boundary_data['shoreline']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # state boundaries
    map_view.multi_line(boundary_data['state']['longitude'],\
                        boundary_data['state']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # ------------------------------
    # add depth label
    map_view.rect(style_parameter['map_view_depth_box_lon'], style_parameter['map_view_depth_box_lat'], \
                  width=style_parameter['map_view_depth_box_width'], height=style_parameter['map_view_depth_box_height'], \
                  width_units='screen',height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_depth', source=map_data_one_slice_depth_bokeh,\
                  text_font_size=style_parameter['annotating_text_font_size'],text_align='left',level='underlay')
    # ------------------------------
    map_view.line('lon', 'lat', source=map_line_bokeh, line_dash=[8,2,8,2], line_color='#00ff00',\
                        nonselection_line_alpha=1.0, line_width=5.,\
                        nonselection_line_color='black')
    map_view.text([style_parameter['cross_default_lon0']],[style_parameter['cross_default_lat0']], ['A'], \
            text_font_size=style_parameter['title_font_size'],text_align='left')
    map_view.text([style_parameter['cross_default_lon1']],[style_parameter['cross_default_lat1']], ['B'], \
            text_font_size=style_parameter['title_font_size'],text_align='left')
    # ------------------------------
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # ==============================
    # plot colorbar

    colorbar_fig = Figure(tools=[], y_range=(0,0.1),plot_width=style_parameter['map_view_plot_width'], \
                      plot_height=style_parameter['colorbar_plot_height'],title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top',bottom='colorbar_bottom',left='colorbar_left',right='colorbar_right',\
                  color='palette_r',source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'], \
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text="""<p style="font-size:16px">""", \
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    # plot cross-section along latitude
    cross_section_plot_width = int(style_parameter['cross_plot_height'] * 1.0 /
                                   plot_depth * plot_lon / 10.)
    cross_view = Figure(plot_width=cross_section_plot_width, plot_height=style_parameter['cross_plot_height'], \
                      tools=style_parameter['cross_view_tools'], title=style_parameter['cross_view_title'], \
                      y_range=[plot_depth, -30],\
                      x_range=[0, plot_lon])
    cross_view.image_rgba('cross_data',x='x',\
                   y='y',dw='dw',dh='dh',\
                   source=cross_data_bokeh, level='image')
    cross_view.text([plot_lon*0.1], [-10], ['A'], \
                text_font_size=style_parameter['title_font_size'],text_align='left',level='underlay')
    cross_view.text([plot_lon*0.9], [-10], ['B'], \
                text_font_size=style_parameter['title_font_size'],text_align='left',level='underlay')
    # ------------------------------
    # change style
    cross_view.title.text_font_size = style_parameter['title_font_size']
    cross_view.title.align = 'center'
    cross_view.title.text_font_style = 'normal'
    cross_view.xaxis.axis_label = style_parameter['cross_view_xlabel']
    cross_view.xaxis.axis_label_text_font_style = 'normal'
    cross_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    cross_view.xaxis.major_label_text_font_size = xlabel_fontsize
    cross_view.yaxis.axis_label = style_parameter['cross_view_ylabel']
    cross_view.yaxis.axis_label_text_font_style = 'normal'
    cross_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    cross_view.yaxis.major_label_text_font_size = xlabel_fontsize
    cross_view.xgrid.grid_line_color = None
    cross_view.ygrid.grid_line_color = None
    cross_view.toolbar.logo = None
    cross_view.toolbar_location = 'right'
    cross_view.toolbar_sticky = False
    # ==============================
    colorbar_fig_right = Figure(tools=[], y_range=(0,0.1),plot_width=cross_section_plot_width, \
                      plot_height=style_parameter['colorbar_plot_height'],title=style_parameter['colorbar_title'])
    colorbar_fig_right.toolbar_location = None

    colorbar_fig_right.quad(top=colorbar_top_cross,bottom=colorbar_bottom_cross,\
                            left=colorbar_left_cross,right=colorbar_right_cross,\
                            color=my_palette)
    colorbar_fig_right.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig_right.xgrid.grid_line_color = None
    colorbar_fig_right.ygrid.grid_line_color = None
    colorbar_fig_right.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig_right.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig_right.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig_right.title.text_font_size = xlabel_fontsize
    colorbar_fig_right.title.align = 'center'
    colorbar_fig_right.title.text_font_style = 'normal'
    # ==============================
    output_file(filename,
                title=style_parameter['html_title'],
                mode=style_parameter['library_source'])
    left_column = Column(depth_slider,
                         map_view,
                         colorbar_fig,
                         annotating_fig01,
                         width=style_parameter['left_column_width'])

    right_column = Column(annotating_fig02,
                          cross_view,
                          colorbar_fig_right,
                          width=cross_section_plot_width)
    layout = Row(left_column, right_column, height=800)
    save(layout)
# With the widgets in place, collect everything the JavaScript callback needs.

# Objects handed to the JavaScript callback; each key is the name under which
# the object is visible inside `cb_script`.
cb_args = dict(
    source=source,
    muAaSlider=mu_Aa,
    muaASlider=mu_aA,
    xoSlider=x_init,
)
# Assign the arguments to the callback code.
cb = CustomJS(args=cb_args, code=cb_script)

# Every time one of the sliders changes value, the JavaScript callback must
# be executed, so register it on each slider (same order: x_init, mu_Aa, mu_aA).
for _widget in (x_init, mu_Aa, mu_aA):
    _widget.js_on_change('value', cb)
del _widget

# Everything is now set up for the interactive plot; only the figure remains.

# Figure showing allele frequency over time.
x_allele_ax = bokeh.plotting.figure(
    width=300,
    height=275,
    x_axis_label='time (a.u.)',
    y_axis_label='allele frequency',
    y_range=[-0.05, 1.05],
)

# Draw the line from the columns of the shared data source.
x_allele_ax.line(x='time', y='x_allele', line_width=2, source=source)
Exemple #8
0
def cluster_gui(doc):
    """Build the interactive station/array-selection GUI and attach it to *doc*.

    Contents of the layout:
      * fig01 -- tile map of all stations (``s1``). A lasso/box selection
        appends the selected stations to ``s2``, each tagged with the current
        slider value as its array number (``ind``).
      * fig02 -- tile map of the stations collected in ``s2``, coloured by
        ``ind`` through a Magma256 colour mapper (range 1..100).
      * fig03 -- figure titled "Waveforms from current selection"; it is only
        created here (presumably filled elsewhere, e.g. by ``update`` -- TODO
        confirm).
      * A DataTable showing ``s2``, Save / Clear / End buttons, a TextInput
        and a Slider for the array number.

    Side effects: rebinds the module globals ``s1``, ``s2``, ``old_indices``,
    ``fig03``, ``slider`` and ``text_input`` (presumably read by the Python
    callbacks defined elsewhere in the file), calls
    ``output_file("tile.html")``, adds the layout as a root of *doc*,
    registers ``update`` as a periodic callback every 900 ms and sets the
    document title.

    Parameters
    ----------
    doc : bokeh.document.Document
        The document the GUI is built into.
    """
    global s1, s2, old_indices, fig03, slider, text_input
    old_indices = []
    output_file("tile.html")
    tile_provider = get_provider(CARTODBPOSITRON)

    fig03 = figure(
        plot_width=400,
        plot_height=400,
        tools=["box_zoom", "wheel_zoom", "reset", "save"],
        title="Waveforms from current selection",
    )

    # Station positions in web-mercator coordinates, plus their NSL names.
    x = []
    y = []
    name = []
    for st in stations:
        xm, ym = merc(st.lat, st.lon)
        x.append(xm)
        y.append(ym)
        name.append(st.nsl_string())

    # create first subplot: all stations, selectable
    plot_width = 400
    plot_height = 400
    d = num.ones_like(x)

    s1 = ColumnDataSource(data=dict(x=x, y=y, ind=d, name=name))
    # range bounds supplied in web mercator coordinates
    fig01 = figure(
        x_axis_type="mercator",
        y_axis_type="mercator",
        plot_width=plot_width,
        plot_height=plot_height,
        tools=[
            "lasso_select", "box_select", "reset", "save", "box_zoom",
            "wheel_zoom"
        ],
        title="Select",
    )
    fig01.add_tile(tile_provider)

    fig01.scatter("x", "y", source=s1, alpha=0.6, size=8)

    # create second subplot: stations collected so far, coloured by array number
    s2 = ColumnDataSource(data=dict(x=[], y=[], ind=[], name=[]))

    color_mapper = LinearColorMapper(palette='Magma256', low=1, high=100)

    fig02 = figure(
        x_axis_type="mercator",
        y_axis_type="mercator",
        plot_width=plot_width,
        plot_height=plot_height,
        x_range=(num.min(x), num.max(x)),
        y_range=(num.min(y), num.max(y)),
        tools=["box_zoom", "wheel_zoom", "reset", "save"],
        title="Stations selected for Array",
    )
    fig02.add_tile(tile_provider)

    fig02.scatter("x",
                  "y",
                  source=s2,
                  alpha=1,
                  color={
                      'field': 'ind',
                      'transform': color_mapper
                  },
                  size=8)

    # Mark the event location on both maps.
    x_event, y_event = merc(event.lat, event.lon)
    fig01.scatter(x_event, y_event, size=8, color="red")
    fig02.scatter(x_event, y_event, size=8, color="red")

    columns = [
        TableColumn(field="x", title="X axis"),
        TableColumn(field="y", title="Y axis"),
        TableColumn(field="ind", title="indices"),
        TableColumn(field="name", title="name"),
    ]

    table = DataTable(
        source=s2,
        columns=columns,
        width=400,
        height=600,
        sortable=True,
        selectable=True,
        editable=True,
    )

    slider = Slider(start=1, end=100, value=1, step=1, title="Array number")
    # Mirror the slider value into a JS global. Use cb_obj (the slider that
    # fired); ``slider`` is not passed in ``args``, so referencing it by name
    # raised a ReferenceError in the browser.
    callback_slider = CustomJS(code="""
        window.source_count = cb_obj.value;
        """)
    slider.js_on_change('value', callback_slider)

    # On every new selection in fig01, append the selected stations to s2,
    # tagged with the current slider value as their array number.
    s1.selected.js_on_change(
        "indices",
        CustomJS(args=dict(s1=s1, s2=s2, s3=slider, table=table),
                 code="""
            var inds = cb_obj.indices;
            var d1 = s1.data;
            var d2 = s2.data;
            const A = s3.value;

            for (var i = 0; i < inds.length; i++) {
                d2['x'].push(d1['x'][inds[i]]);
                d2['y'].push(d1['y'][inds[i]]);
                d2['name'].push(d1['name'][inds[i]]);
                d2['ind'].push(A);
            }
            s2.change.emit();
            table.change.emit();
            // Reassign so the in-place pushes are treated as a data change.
            s2.data = s2.data;
        """),
    )

    # Download the rows of s1 currently selected in fig01 as arrays.txt.
    # NOTE(review): this exports s1, whose 'ind' column is always 1
    # (num.ones_like above); confirm s2 was not the intended source.
    savebutton = Button(label="Save", button_type="success")
    savebutton.js_on_click(CustomJS(
        args=dict(source_data=s1),
        code="""
            var inds = source_data.selected.indices;
            var data = source_data.data;
            var out = "name, x, y, ind\\n";
            for (var i = 0; i < inds.length; i++) {
                out += data['name'][inds[i]] + "," + data['x'][inds[i]] + "," + data['y'][inds[i]] + "," + data['ind'][inds[i]]   + "\\n";
            }
            var file = new Blob([out], {type: 'text/plain'});
            var elem = window.document.createElement('a');
            elem.href = window.URL.createObjectURL(file);
            elem.download = 'arrays.txt';
            document.body.appendChild(elem);
            elem.click();
            document.body.removeChild(elem);
            """,
    ))

    tooltips = [
        ("X:", "@x"),
        ("Y:", "@y"),
        ("Array:", "@ind"),
        ("Station:", "@name"),
    ]

    fig01.add_tools(HoverTool(tooltips=tooltips))
    fig02.add_tools(HoverTool(tooltips=tooltips))
    fig03.add_tools(HoverTool(tooltips=tooltips))

    endbutton = Button(label="End and proceed", button_type="success")
    endbutton.on_click(button_callback)

    clearbutton = Button(label="Clear all", button_type="success")
    # NOTE(review): "Clear all" shares the handler of "Clear last selection";
    # confirm whether a dedicated clear-all handler was intended.
    clearbutton.on_click(clearbuttonlast_callback)
    clearbuttonlast = Button(label="Clear last selection",
                             button_type="success")
    clearbuttonlast.on_click(clearbuttonlast_callback)
    clearbuttonone = Button(label="Remove one from list",
                            button_type="success")
    clearbuttonone.on_click(clearbuttonone_callback)

    # Clicks every toolbar "Reset" icon on the page.
    # NOTE(review): this button is not placed in the layout below, so it is
    # never shown; add it to ``buttons`` if it is wanted.
    b = Button(label="Reset all plots")
    b.js_on_click(
        CustomJS(code="""\
document.querySelectorAll('.bk-tool-icon-reset[title="Reset"]').forEach(d => d.click())
"""))

    text_input = TextInput(value="1", title="Array number:")

    buttons = column(clearbuttonlast, clearbuttonone, clearbutton, savebutton,
                     endbutton)
    inputs = column(text_input, slider)

    layout_grid = layout([fig01, fig02, buttons], [fig03, inputs, table])

    doc.add_root(layout_grid)
    doc.add_periodic_callback(update, 900)
    # Title the document this handler was given rather than curdoc().
    doc.title = "Array selection"
    'CDS': CDSimages['weighted'],
    'NSA': NSA_weighted
},
                           code="""
        FFidx = Math.round((cb_obj.value - cb_obj.start ) / cb_obj.step)
        CDS[0].data.imageData = [NSA.map(PF => PF.map(f => f[FFidx][0])).flat()]
        CDS[1].data.imageData = [NSA.map(PF => PF.map(f => f[FFidx][1])).flat()]
        CDS[0].change.emit();
        CDS[1].change.emit();
        """)
# Slider driving the weighted-image callback (sliderCallbackW).
# NOTE(review): assumes nFFs > 1 -- the step divides by (nFFs - 1).
sliderW = Slider(start=0,
                 end=1,
                 value=0,
                 step=1.0 / (nFFs - 1),
                 title="Fat fraction")
sliderW.js_on_change('value', sliderCallbackW)

# Callback for the unweighted images (CDSimages['unweighted'], NSA_equalWeights).
# On each slider change it:
#   - converts the slider value to the nearest step index FFidx;
#   - takes element FFidx of every innermost entry of NSA (component 0 for
#     CDS[0], component 1 for CDS[1]);
#   - stores the flattened result, wrapped in a one-element list, as
#     imageData, then re-renders both sources.
# FFidx is declared with const so it no longer leaks onto the global window
# object (an undeclared assignment would also throw in strict mode).
sliderCallbackUW = CustomJS(args={
    'CDS': CDSimages['unweighted'],
    'NSA': NSA_equalWeights
},
                            code="""
        const FFidx = Math.round((cb_obj.value - cb_obj.start ) / cb_obj.step)
        CDS[0].data.imageData = [NSA.map(PF => PF.map(f => f[FFidx][0])).flat()]
        CDS[1].data.imageData = [NSA.map(PF => PF.map(f => f[FFidx][1])).flat()]
        CDS[0].change.emit();
        CDS[1].change.emit();
        """)
sliderUW = Slider(start=0,
                  end=1,
                  value=0,
def plot_dispersion_bokeh(filename, period_array, curve_data_array,
                          boundary_data, style_parameter):
    '''
    Plot dispersion maps and curves using bokeh
    
    Input:
        filename is the filename of the resulting html file
        period_array is a list of period
        curve_data_array is a list of dispersion curves
        boundary_data is a list of boundaries
        style_parameter contains plotting parameters 
    
    Output:
        None
        
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    # ==============================
    # prepare data
    map_data_all_slices_velocity = []
    map_data_all_slices_period = []
    map_data_all_slices_color = []
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    nperiod = len(period_array)
    ncurve = len(curve_data_array)
    ncolor = len(palette)
    palette_r = palette[::-1]
    colorbar_top = [0.1 for i in range(ncolor)]
    colorbar_bottom = [0 for i in range(ncolor)]
    for iperiod in range(nperiod):
        one_slice_lat_list = []
        one_slice_lon_list = []
        one_slice_vel_list = []

        map_period = period_array[iperiod]
        for icurve in range(ncurve):
            acurve = curve_data_array[icurve]
            curve_lat = acurve['latitude']
            curve_lon = acurve['longitude']
            curve_vel = acurve['velocity']
            curve_period = acurve['period']
            one_slice_lat_list.append(curve_lat)
            one_slice_lon_list.append(curve_lon)
            if map_period in curve_period:
                curve_period_index = curve_period.index(map_period)
                one_slice_vel_list.append(curve_vel[curve_period_index])
            else:
                one_slice_vel_list.append(style_parameter['nan_value'])
        # get color for dispersion values
        one_slice_vel_mean = np.nanmean(one_slice_vel_list)
        one_slice_vel_std = np.nanstd(one_slice_vel_list)

        color_min = one_slice_vel_mean - one_slice_vel_std * style_parameter[
            'spread_factor']
        color_max = one_slice_vel_mean + one_slice_vel_std * style_parameter[
            'spread_factor']
        color_step = (color_max - color_min) * 1. / ncolor
        one_slice_color_list = get_color_list(one_slice_vel_list,color_min,color_max,palette_r,\
                                             style_parameter['nan_value'],style_parameter['nan_color'])
        colorbar_left = np.linspace(color_min, color_max - color_step, ncolor)
        colorbar_right = np.linspace(color_min + color_step, color_max, ncolor)
        if one_slice_lat_list:
            map_data_all_slices_velocity.append(one_slice_vel_list)
            map_data_all_slices_period.append(
                'Period: {0:6.1f} s'.format(map_period))
            map_data_all_slices_color.append(one_slice_color_list)
            colorbar_data_all_left.append(colorbar_left)
            colorbar_data_all_right.append(colorbar_right)
    # get location for all points
    map_lat_list, map_lon_list = [], []
    map_lat_label_list, map_lon_label_list = [], []
    for i in range(ncurve):
        acurve = curve_data_array[i]
        map_lat_list.append(acurve['latitude'])
        map_lon_list.append(acurve['longitude'])
        map_lat_label_list.append('Lat: {0:12.3f}'.format(acurve['latitude']))
        map_lon_label_list.append('Lon: {0:12.3f}'.format(acurve['longitude']))
    # data for the map view plot
    map_view_label_lon = style_parameter['map_view_period_label_lon']
    map_view_label_lat = style_parameter['map_view_period_label_lat']

    map_data_one_slice = map_data_all_slices_color[
        style_parameter['map_view_default_index']]
    map_data_one_slice_period = map_data_all_slices_period[
        style_parameter['map_view_default_index']]
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(map_lat_list=map_lat_list,\
                                                          map_lon_list=map_lon_list,\
                                                          map_data_one_slice=map_data_one_slice))
    map_data_one_slice_period_bokeh = ColumnDataSource(
        data=dict(lat=[map_view_label_lat],
                  lon=[map_view_label_lon],
                  map_period=[map_data_one_slice_period]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(map_data_all_slices_color=map_data_all_slices_color,\
                                                          map_data_all_slices_period=map_data_all_slices_period))

    # data for the colorbar
    colorbar_data_one_slice = {}
    colorbar_data_one_slice['colorbar_left'] = colorbar_data_all_left[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice['colorbar_right'] = colorbar_data_all_right[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(colorbar_top=colorbar_top,colorbar_bottom=colorbar_bottom,
                                                               colorbar_left=colorbar_data_one_slice['colorbar_left'],\
                                                               colorbar_right=colorbar_data_one_slice['colorbar_right'],\
                                                               palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(colorbar_data_all_left=colorbar_data_all_left,\
                                                                colorbar_data_all_right=colorbar_data_all_right))
    # data for dispersion curves
    curve_default_index = style_parameter['curve_default_index']
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(lat=[map_lat_list[curve_default_index]],\
                                                     lon=[map_lon_list[curve_default_index]],\
                                                     color=[map_data_one_slice[curve_default_index]],\
                                                     index=[curve_default_index]))
    selected_curve_data = curve_data_array[curve_default_index]
    selected_curve_data_bokeh = ColumnDataSource(data=dict(curve_period=selected_curve_data['period'],\
                                                          curve_velocity=selected_curve_data['velocity']))

    period_all = []
    velocity_all = []
    for acurve in curve_data_array:
        period_all.append(acurve['period'])
        velocity_all.append(acurve['velocity'])
    curve_data_all_bokeh = ColumnDataSource(
        data=dict(period_all=period_all, velocity_all=velocity_all))

    selected_curve_lat_label_bokeh = ColumnDataSource(data=dict(x=[style_parameter['curve_lat_label_x']], \
                                                                y=[style_parameter['curve_lat_label_y']],\
                                                                lat_label=[map_lat_label_list[curve_default_index]]))
    selected_curve_lon_label_bokeh = ColumnDataSource(data=dict(x=[style_parameter['curve_lon_label_x']], \
                                                                y=[style_parameter['curve_lon_label_y']],\
                                                                lon_label=[map_lon_label_list[curve_default_index]]))
    all_curve_lat_label_bokeh = ColumnDataSource(data=dict(
        map_lat_label_list=map_lat_label_list))
    all_curve_lon_label_bokeh = ColumnDataSource(data=dict(
        map_lon_label_list=map_lon_label_list))
    # ==============================
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'], \
                      plot_height=style_parameter['map_view_plot_height'], \
                      y_range=[style_parameter['map_view_lat_min'],\
                    style_parameter['map_view_lat_max']], x_range=[style_parameter['map_view_lon_min'],\
                    style_parameter['map_view_lon_max']], tools=style_parameter['map_view_tools'],\
                    title=style_parameter['map_view_title'])
    # ------------------------------
    # add boundaries to map view
    # country boundaries
    map_view.multi_line(boundary_data['country']['longitude'],\
                        boundary_data['country']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # marine boundaries
    map_view.multi_line(boundary_data['marine']['longitude'],\
                        boundary_data['marine']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # shoreline boundaries
    map_view.multi_line(boundary_data['shoreline']['longitude'],\
                        boundary_data['shoreline']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # state boundaries
    map_view.multi_line(boundary_data['state']['longitude'],\
                        boundary_data['state']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # ------------------------------
    # add period label
    map_view.rect(style_parameter['map_view_period_box_lon'], style_parameter['map_view_period_box_lat'], \
                  width=style_parameter['map_view_period_box_width'], height=style_parameter['map_view_period_box_height'], \
                  width_units='screen',height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_period', source=map_data_one_slice_period_bokeh,\
                  text_font_size=style_parameter['annotating_text_font_size'],text_align='left',level='underlay')
    # ------------------------------
    # plot dots
    map_view.circle('map_lon_list', 'map_lat_list', color='map_data_one_slice', \
                    source=map_data_one_slice_bokeh, size=style_parameter['marker_size'],\
                    line_width=0.2, line_color='black', alpha=1.0,\
                    selection_color='map_data_one_slice', selection_line_color='black',\
                    selection_fill_alpha=1.0,\
                    nonselection_fill_alpha=1.0, nonselection_fill_color='map_data_one_slice',\
                    nonselection_line_color='black', nonselection_line_alpha=1.0)
    map_view.circle('lon', 'lat', color='color', source=selected_dot_on_map_bokeh, \
                    line_color='#00ff00', line_width=4.0, alpha=1.0, \
                    size=style_parameter['selected_marker_size'])
    # ------------------------------
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # ==============================
    # plot colorbar
    colorbar_fig = Figure(tools=[], y_range=(0,0.1),plot_width=style_parameter['map_view_plot_width'], \
                          plot_height=style_parameter['colorbar_plot_height'],title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top',bottom='colorbar_bottom',left='colorbar_left',right='colorbar_right',\
                      fill_color='palette_r',source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    curve_fig = Figure(plot_width=style_parameter['curve_plot_width'], plot_height=style_parameter['curve_plot_height'], \
                       y_range=(style_parameter['curve_y_min'],style_parameter['curve_y_max']), \
                       x_range=(style_parameter['curve_x_min'],style_parameter['curve_x_max']),x_axis_type='log',\
                        tools=['save','box_zoom','wheel_zoom','reset','crosshair','pan'],
                        title=style_parameter['curve_title'])
    # ------------------------------
    curve_fig.rect([style_parameter['curve_label_box_x']], [style_parameter['curve_label_box_y']], \
                   width=style_parameter['curve_label_box_width'], height=style_parameter['curve_label_box_height'], \
                   width_units='screen', height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    curve_fig.text('x', 'y', \
                   'lat_label', source=selected_curve_lat_label_bokeh)
    curve_fig.text('x', 'y', \
                   'lon_label', source=selected_curve_lon_label_bokeh)
    # ------------------------------
    curve_fig.line('curve_period',
                   'curve_velocity',
                   source=selected_curve_data_bokeh,
                   color='black')
    curve_fig.circle('curve_period',
                     'curve_velocity',
                     source=selected_curve_data_bokeh,
                     size=5,
                     color='black')
    # ------------------------------
    curve_fig.title.text_font_size = style_parameter['title_font_size']
    curve_fig.title.align = 'center'
    curve_fig.title.text_font_style = 'normal'
    curve_fig.xaxis.axis_label = style_parameter['curve_xlabel']
    curve_fig.xaxis.axis_label_text_font_style = 'normal'
    curve_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig.yaxis.axis_label = style_parameter['curve_ylabel']
    curve_fig.yaxis.axis_label_text_font_style = 'normal'
    curve_fig.yaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig.yaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig.xgrid.grid_line_dash = [4, 2]
    curve_fig.ygrid.grid_line_dash = [4, 2]
    curve_fig.xaxis[0].formatter = PrintfTickFormatter(format="%4.0f")
    curve_fig.toolbar.logo = None
    curve_fig.toolbar_location = 'above'
    curve_fig.toolbar_sticky = False
    # ==============================
    map_data_one_slice_js = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                                          map_data_one_slice_bokeh=map_data_one_slice_bokeh,\
                                                          selected_curve_data_bokeh=selected_curve_data_bokeh,\
                                                          curve_data_all_bokeh=curve_data_all_bokeh,\
                                                          selected_curve_lat_label_bokeh=selected_curve_lat_label_bokeh,\
                                                          selected_curve_lon_label_bokeh=selected_curve_lon_label_bokeh,\
                                                          all_curve_lat_label_bokeh=all_curve_lat_label_bokeh,\
                                                          all_curve_lon_label_bokeh=all_curve_lon_label_bokeh), code="""
    
    var inds = cb_obj.indices
    
    selected_dot_on_map_bokeh.data['index'] = [inds]
    
    var new_slice = map_data_one_slice_bokeh.data
    
    selected_dot_on_map_bokeh.data['lat'] = [new_slice['map_lat_list'][inds]]
    selected_dot_on_map_bokeh.data['lon'] = [new_slice['map_lon_list'][inds]]
    selected_dot_on_map_bokeh.data['color'] = [new_slice['map_data_one_slice'][inds]]
    
    selected_dot_on_map_bokeh.change.emit()
    
    selected_curve_data_bokeh.data['curve_period'] = curve_data_all_bokeh.data['period_all'][inds]
    selected_curve_data_bokeh.data['curve_velocity'] = curve_data_all_bokeh.data['velocity_all'][inds]
    
    selected_curve_data_bokeh.change.emit()
    
    var all_lat_labels = all_curve_lat_label_bokeh.data['map_lat_label_list']
    var all_lon_labels = all_curve_lon_label_bokeh.data['map_lon_label_list']
    
    selected_curve_lat_label_bokeh.data['lat_label'] = [all_lat_labels[inds]]
    selected_curve_lon_label_bokeh.data['lon_label'] = [all_lon_labels[inds]]
    
    selected_curve_lat_label_bokeh.change.emit()
    selected_curve_lon_label_bokeh.change.emit()
    """)
    map_data_one_slice_bokeh.selected.js_on_change('indices',
                                                   map_data_one_slice_js)
    # ==============================
    period_slider_callback = CustomJS(args=dict(map_data_all_slices_bokeh=map_data_all_slices_bokeh,\
                                  map_data_one_slice_bokeh=map_data_one_slice_bokeh,\
                                  colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh, \
                                  colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,\
                                  selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                  map_data_one_slice_period_bokeh=map_data_one_slice_period_bokeh),\
                       code="""
    var p_index = Math.round(cb_obj.value)
    var map_data_all_slices = map_data_all_slices_bokeh.data
    
    
    var map_data_new_slice = map_data_all_slices['map_data_all_slices_color'][p_index]
    map_data_one_slice_bokeh.data['map_data_one_slice'] = map_data_new_slice
    map_data_one_slice_bokeh.change.emit()
    
    var color_data_all_slices = colorbar_data_all_slices_bokeh.data
    colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][p_index]
    colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][p_index]
    colorbar_data_one_slice_bokeh.change.emit()
    
    var selected_index = selected_dot_on_map_bokeh.data['index']
    selected_dot_on_map_bokeh.data['color'] = [map_data_new_slice[selected_index]]
    selected_dot_on_map_bokeh.change.emit()
    
    map_data_one_slice_period_bokeh.data['map_period'] = [map_data_all_slices['map_data_all_slices_period'][p_index]]
    map_data_one_slice_period_bokeh.change.emit()
    """)
    period_slider = Slider(start=0, end=nperiod-1, value=style_parameter['map_view_default_index'], \
                           step=1, title=style_parameter['period_slider_title'], \
                           width=style_parameter['period_slider_plot_width'],\
                           height=50)
    period_slider.js_on_change('value', period_slider_callback)
    period_slider_callback.args['period_index'] = period_slider
    # ==============================
    curve_slider_callback = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                              map_data_one_slice_bokeh=map_data_one_slice_bokeh,\
                                              selected_curve_data_bokeh=selected_curve_data_bokeh,\
                                              curve_data_all_bokeh=curve_data_all_bokeh,\
                                              selected_curve_lat_label_bokeh=selected_curve_lat_label_bokeh,\
                                              selected_curve_lon_label_bokeh=selected_curve_lon_label_bokeh,\
                                              all_curve_lat_label_bokeh=all_curve_lat_label_bokeh,\
                                              all_curve_lon_label_bokeh=all_curve_lon_label_bokeh),\
                                    code="""
    var c_index = Math.round(cb_obj.value)
    
    var one_slice = map_data_one_slice_bokeh.data
    
    selected_dot_on_map_bokeh.data['index'] = [c_index]
    selected_dot_on_map_bokeh.data['lat'] = [one_slice['map_lat_list'][c_index]]
    selected_dot_on_map_bokeh.data['lon'] = [one_slice['map_lon_list'][c_index]]
    selected_dot_on_map_bokeh.data['color'] = [one_slice['map_data_one_slice'][c_index]]
    
    selected_dot_on_map_bokeh.change.emit()
    
    selected_curve_data_bokeh.data['curve_period'] = curve_data_all_bokeh.data['period_all'][c_index]
    selected_curve_data_bokeh.data['curve_velocity'] = curve_data_all_bokeh.data['velocity_all'][c_index]
    
    selected_curve_data_bokeh.change.emit()
    
    var all_lat_labels = all_curve_lat_label_bokeh.data['map_lat_label_list']
    var all_lon_labels = all_curve_lon_label_bokeh.data['map_lon_label_list']
    
    selected_curve_lat_label_bokeh.data['lat_label'] = [all_lat_labels[c_index]]
    selected_curve_lon_label_bokeh.data['lon_label'] = [all_lon_labels[c_index]]
    
    selected_curve_lat_label_bokeh.change.emit()
    selected_curve_lon_label_bokeh.change.emit()
    """)
    curve_slider = Slider(start=0, end=ncurve-1, value=style_parameter['curve_default_index'], \
                          step=1, title=style_parameter['curve_slider_title'], width=style_parameter['curve_plot_width'],\
                          height=50)
    curve_slider.js_on_change('value', curve_slider_callback)
    curve_slider_callback.args['curve_index'] = curve_slider
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'], \
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],\
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    output_file(filename,
                title=style_parameter['html_title'],
                mode=style_parameter['library_source'])
    left_fig = Column(period_slider, map_view, colorbar_fig, annotating_fig01,\
                    width=style_parameter['left_column_width'] )
    right_fig = Column(curve_slider, curve_fig, annotating_fig02, \
                    width=style_parameter['right_column_width'] )
    layout = Row(left_fig, right_fig)
    save(layout)
Exemple #11
0
    def getPlotByYear(title, plotdict, classification, parameters, contextdict, maxval):
        """Build a bar chart of per-class counts with a year slider beneath it.

        Parameters
        ----------
        title : str
            Plot title.
        plotdict : dict
            Maps class name -> sequence indexed by year bucket. An optional
            "Time" entry only determines how many year buckets exist.
        classification : iterable of str
            Class names to plot; a "Time" entry, if present, is dropped.
            The caller's object is not mutated (a sorted copy is used).
        parameters : container
            Checked for ``bokutils.ACCUMULATE_TRUE`` / ``bokutils.ACCUMULATE_FALSE``.
        contextdict : dict
            ``contextdict["name"][0]`` selects the hover header ("All" shows a
            fixed label, anything else shows the hovered class name).
        maxval : number
            Upper bound for the y axis; 5% headroom is added on top.

        Returns
        -------
        Column
            The bar figure stacked above a widget box holding the year slider.
        """
        category = str(contextdict["name"][0]).replace("'", "").encode('ascii', 'ignore').decode("utf-8")
        classification = sorted(classification)

        # Hover card header: a fixed label for the "All" category, otherwise the
        # hovered bar's class. The outer div's style attribute must be closed
        # with a quote, or the browser swallows the following markup into it.
        if category == 'All':
            header = '<h1 style="margin: 0; font-size: 12px;"> All museums:</h1>'
        else:
            header = '<h1 style="margin: 0; font-size: 12px;"> @classification</h1>'
        hover = HoverTool(tooltips="""
        <div style="opacity: .8; padding: 5px; background-color: @type_color; color: @font_color;">
            %s
            <h1 style="margin: 0; font-size: 24px;"><strong> @counts </strong></h1>
        </div>
        """ % header)

        # "Time" only tells us how many year buckets exist; it is not a bar.
        bucketlen = 0
        if "Time" in classification:
            bucketlen = len(plotdict["Time"])
            classification.remove("Time")
        classlen = len(classification)

        # Bars and data are both keyed by the shortened legend key.
        classification = [bokutils.makeLegendKey(c) for c in classification]
        plotdict = {bokutils.makeLegendKey(key): val for key, val in plotdict.items()}

        my_palette = bokutils.getColors(classlen, True, False)

        # One row per year bucket: counts rounded half-up, 0 for missing classes.
        valuearr = []
        for year in range(bucketlen):
            valuearr.append([int(plotdict[c][year] + 0.5) if c in plotdict else 0
                             for c in classification])

        # Initial bars show the last bucket (the slider starts at LAST_YEAR).
        # Without any time series, fall back to zeros instead of failing on an
        # unbound name.
        counts = list(valuearr[-1]) if valuearr else [0] * classlen
        labellist = [str(n) for n in counts]

        source = ColumnDataSource(data=dict(
            classification=classification,
            counts=counts,
            labellist=labellist,
            type_color=list(my_palette),
            font_color=[bokutils.contrasting_text_color(color) for color in my_palette]))

        allsource = ColumnDataSource(data=dict(plotlist=valuearr))

        # Default to no accumulation when neither flag is given. If both are
        # given, ACCUMULATE_FALSE wins, matching the original precedence.
        acc = False
        if bokutils.ACCUMULATE_TRUE in parameters:
            acc = True
        if bokutils.ACCUMULATE_FALSE in parameters:
            acc = False

        paramsource = ColumnDataSource(data=dict(params=[acc, False, "Time"]))

        p = figure(x_range=classification, plot_height=bokutils.PLOT_HEIGHT,
                   plot_width=bokutils.PLOT_WIDTH,
                   tools=[hover],
                   toolbar_location=None,
                   title=title)

        # Fixed padding for small numbers of bars (could also be a percentage).
        if len(classification) < 6:
            p.x_range.range_padding = 0.7

        p.vbar(x='classification', top='counts', width=0.9, line_width=5.9,
               source=source, legend="classification", line_color='white',
               fill_color=factor_cmap('classification', palette=my_palette,
                                      factors=classification))

        p.xgrid.grid_line_color = None
        p.y_range.start = 0
        p.y_range.end = maxval + int(0.05 * float(maxval))

        blabel = LabelSet(x='classification', y='counts', text='labellist',
                          level='glyph', x_offset=-11, y_offset=2,
                          source=source, render_mode='canvas')
        p.add_layout(blabel)

        p.legend.orientation = "vertical"
        p.xaxis.major_label_orientation = 1.2
        # Detach the legend from the plot area and re-add it on the right.
        new_legend = p.legend[0]
        p.legend[0].plot = None
        p.add_layout(new_legend, 'right')

        slider = Slider(start=bokutils.FIRST_YEAR,
                        end=bokutils.LAST_YEAR,
                        value=bokutils.LAST_YEAR,
                        step=1,
                        title='Year',
                        bar_color="#b3cccc")

        # When the slider changes, copy the selected year's row into the bar
        # source and refresh the bar labels.
        # NOTE(review): bucket 0 is assumed to be the year 1960. Confirm this
        # matches bokutils.FIRST_YEAR (the slider's start).
        callback = CustomJS(args=dict(source=source, slider=slider,
                                      source2=allsource, source3=paramsource),
                            code="""
            var counts = source.data["counts"];
            var labels = source.data["labellist"];
            var allcounts = source2.data["plotlist"];
            var idx = slider.value - 1960;
            // Ignore years without a data bucket instead of throwing.
            if (idx < 0 || idx >= allcounts.length) {
                return;
            }
            var row = allcounts[idx];
            for (var i = 0; i < counts.length; i++) {
                counts[i] = row[i];
                labels[i] = counts[i] + "";
            }
            source.change.emit();
        """)
        slider.js_on_change('value', callback)

        wb = widgetbox(children=[slider], sizing_mode='scale_width')
        return Column(children=[p, wb], sizing_mode='scale_both')
Exemple #12
0
def main():
    print('''Please select the CSV dataset you\'d like to use.
    The dataset should contain these columns:
        - metric to apply threshold to
        - indicator of event to detect (e.g. malicious activity)
            - Please label this as 1 or 0 (true or false); 
            This will not work otherwise!
    ''')
    # Import the dataset
    imported_data = None
    while isinstance(imported_data, pd.DataFrame) == False:
        file_path = input('Enter the path of your dataset: ')
        imported_data = file_to_df(file_path)

    time.sleep(1)

    print(f'''\nGreat! Here is a preview of your data:
Imported fields:''')
    # List headers by column index.
    cols = list(imported_data.columns)
    for index in range(len(cols)):
        print(f'{index}: {cols[index]}')
    print(f'Number of records: {len(imported_data.index)}\n')
    # Preview the DataFrame
    time.sleep(1)
    print(imported_data.head(), '\n')

    # Prompt for the metric and source of truth.
    time.sleep(1)
    metric_col, indicator_col = columns_picker(cols)
    # User self-validation.
    col_check = input('Can you confirm if this is correct? (y/n): ').lower()
    # If it's wrong, let them try again
    while col_check != 'y':
        metric_col, indicator_col = columns_picker(cols)
        col_check = input(
            'Can you confirm if this is correct? (y/n): ').lower()
    else:
        print(
            '''\nGreat! Thanks for your patience. Generating summary stats now..\n'''
        )

    # Generate summary stats.
    time.sleep(1)
    malicious, normal = classification_split(imported_data, metric_col,
                                             indicator_col)
    mal_mean = malicious.mean()
    mal_stddev = malicious.std()
    mal_count = malicious.size
    mal_median = malicious.median()
    norm_mean = normal.mean()
    norm_stddev = normal.std()
    norm_count = normal.size
    norm_median = normal.median()

    print(f'''Normal vs Malicious Summary (metric = {metric_col}):
Normal:
-----------------------------
Observations: {round(norm_count, 2)}
Average: {round(norm_mean, 2)}
Median: {round(norm_median, 2)}
Standard Deviation: {round(norm_stddev, 2)}

Malicious:
-----------------------------
Observations: {round(mal_count, 2)}
Average: {round(mal_mean, 2)}
Median: {round(mal_median, 2)}
Standard Deviation: {round(mal_stddev, 2)}
''')

    # Insights and advisories
    # Provide the accuracy metrics of a generic threshold at avg + 3 std deviations
    generic_threshold = confusion_matrix(
        malicious, normal, threshold_calc(norm_mean, norm_stddev, 3))

    time.sleep(1)
    print(
        f'''A threshold at (average + 3x standard deviations) {metric_col} would result in:
    - True Positives (correctly identified malicious events: {generic_threshold['TP']:,}
    - False Positives (wrongly identified normal events: {generic_threshold['FP']:,}
    - True Negatives (correctly identified normal events: {generic_threshold['TN']:,}
    - False Negatives (wrongly identified malicious events: {generic_threshold['FN']:,}

    Accuracy Metrics:
    - Precision (what % of events above threshold are actually malicious): {round(generic_threshold['precision'] * 100, 1)}%
    - Recall (what % of malicious events did we catch): {round(generic_threshold['recall'] * 100, 1)}%
    - F1 Score (blends precision and recall): {round(generic_threshold['f1_score'] * 100, 1)}%'''
    )

    # Distribution skew check.
    if norm_mean >= (norm_median * 1.1):
        time.sleep(1)
        print(
            f'''\nYou may want to be cautious as your normal traffic\'s {metric_col} 
has a long tail towards high values. The median is {round(norm_median, 2)} 
compared to {round(norm_mean, 2)} for the average.''')

    if mal_mean < threshold_calc(norm_mean, norm_stddev, 2):
        time.sleep(1)
        print(
            f'''\nWarning: you may find it difficult to avoid false positives as the average
{metric_col} for malicious traffic is under the 95th percentile of the normal traffic.'''
        )

    # For fun/anticipation. Actually a nerd joke because of the method we'll be using.
    if '-q' not in sys.argv[1:]:
        time.sleep(1)
        play_a_game.billy()
        decision = input('yes/no: ').lower()
        while decision != 'yes':
            time.sleep(1)
            print('...That\'s no fun...')
            decision = input('Let\'s try that again: ').lower()

    # Let's get to the simulations!
    time.sleep(1)
    print('''\nInstead of manually experimenting with threshold multipliers, 
let\'s simulate a range of options and see what produces the best result. 
This is similar to what is known as \"Monte Carlo simulation\".\n''')

    # Initialize session name & create app folder if there isn't one.
    time.sleep(1)
    session_name = input('Please provide a name for this project/session: ')
    session_folder = make_folder(session_name)

    # Generate list of multipliers to iterate over.
    time.sleep(1)
    mult_start = float(
        input(
            'Please provide the minimum multiplier you want to start at. We recommend 2: '
        ))
    # Set the max to how many std deviations away the sample max is.
    mult_end = (imported_data[metric_col].max() - norm_mean) / norm_stddev
    mult_interval = float(
        input('Please provide the desired gap between multiplier options: '))
    # range() only allows integers, let's manually populate a list
    multipliers = []
    mult_counter = mult_start
    while mult_counter < mult_end:
        multipliers.append(round(mult_counter, 2))
        mult_counter += mult_interval
    print('Generating simulations..\n')

    # Run simulations using our multipliers.
    simulations = monte_carlo(malicious, normal, norm_mean, norm_stddev,
                              multipliers)
    print('Done!')
    time.sleep(1)

    # Save simulations as CSV for later use.
    simulation_filepath = os.path.join(
        session_folder, f'{session_name}_simulation_results.csv')
    simulations.to_csv(simulation_filepath, index=False)
    print(f'Saved results to: {simulation_filepath}')
    # Find the first threshold with the highest F1 score.
    # This provides a balanced approach between precision and recall.
    f1_max = simulations[simulations.f1_score ==
                         simulations.f1_score.max()].head(1)
    f1_max_mult = f1_max.squeeze()['multiplier']
    time.sleep(1)
    print(
        f'''\nBased on the F1 score metric, setting a threshold at {round(f1_max_mult,1)} standard deviations
above the average magnitude might provide optimal results.\n''')
    time.sleep(1)
    print(f'''{f1_max}

We recommend that you skim the CSV and the following visualization outputs 
to sanity check results and make your own judgement.
''')

    # Now for the fun part..generating the visualizations via Bokeh.

    # Header & internal CSS.
    title_text = '''
    <style>

    @font-face {
        font-family: RobotoBlack;
        src: url(fonts/Roboto-Black.ttf);
        font-weight: bold;
    }

    
     @font-face {
        font-family: RobotoBold;
        src: url(fonts/Roboto-Bold.ttf);
        font-weight: bold;
    }   
    
    @font-face {
        font-family: RobotoRegular;
        src: url(fonts/Roboto-Regular.ttf);
    }

    body {
        background-color: #f2ebe6;
    }

    title_header {
        font-size: 80px;
        font-style: bold;
        font-family: RobotoBlack, Helvetica;
        font-weight: bold;
        margin-bottom: -200px;
    }

    h1, h2, h3 {
        font-family: RobotoBlack, Helvetica;
        color: #313596;
    }

    p {
        font-size: 12px;
        font-family: RobotoRegular
    }

    b {
        color: #58c491;
    }

    th, td {
        text-align:left;
        padding: 5px;
    }

    tr:nth-child(even) {
        background-color: white;
        opacity: .7;
    }

    .vertical { 
        border-left: 1px solid black; 
        height: 190px; 
            } 
    </style>

        <title_header style="text-align:left; color: white;">
            Cream.
        </title_header>
        <p style="font-family: RobotoBold, Helvetica;
        font-size:18px;
        margin-top: 0px;
        margin-left: 5px;">
            Because time is money, and <b style="font-size=18px;">"Cash Rules Everything Around Me"</b>.
        </p>
    </div>
    '''

    title_div = Div(text=title_text,
                    width=800,
                    height=160,
                    margin=(40, 0, 0, 70))

    # Summary stats from earlier.
    summary_text = f'''
    <h1>Results Overview</h1> 
    <i>metric = magnitude</i>

    <table style="width:100%">
      <tr>
        <th>Metric</th>
        <th>Normal Events</th>
        <th>Malicious Events</th>
      </tr>
      <tr>
        <td>Observations</td>
        <td>{norm_count:,}</td>
        <td>{mal_count:,}</td>
      </tr>
      <tr>
        <td>Average</td>
        <td>{round(norm_mean, 2):,}</td>
        <td>{round(mal_mean, 2):,}</td>
      </tr>
      <tr>
        <td>Median</td>
        <td>{round(norm_median, 2):,}</td>
        <td>{round(mal_median, 2):,}</td>
      </tr> 
      <tr>
        <td>Standard Deviation</td>
        <td>{round(norm_stddev, 2):,}</td>
        <td>{round(mal_stddev, 2):,}</td>
      </tr> 
    </table>
    '''

    summary_div = Div(text=summary_text,
                      width=470,
                      height=320,
                      margin=(3, 0, -70, 73))

    # Results of the hypothetical threshold.
    hypothetical = f'''
    <h1>"Rule of thumb" Hypothetical Threshold</h1>
    <p>A threshold at <i>(average + 3x standard deviations)</i> {metric_col} would result in:</p>
    <ul>
        <li>True Positives (correctly identified malicious events: 
            <b>{generic_threshold['TP']:,}</b></li>
        <li>False Positives (wrongly identified normal events:
            <b>{generic_threshold['FP']:,}</b></li>
        <li>True Negatives (correctly identified normal events: 
            <b>{generic_threshold['TN']:,}</b></li>
        <li>False Negatives (wrongly identified malicious events: 
            <b>{generic_threshold['FN']:,}</b></li>
    </ul>
    <h2>Accuracy Metrics</h2>
    <ul>
        <li>Precision (what % of events above threshold are actually malicious): 
            <b>{round(generic_threshold['precision'] * 100, 1)}%</b></li>
        <li>Recall (what % of malicious events did we catch): 
            <b>{round(generic_threshold['recall'] * 100, 1)}%</b></li>
        <li>F1 Score (blends precision and recall): 
            <b>{round(generic_threshold['f1_score'] * 100, 1)}%</b></li>
    </ul>
    '''

    hypo_div = Div(text=hypothetical,
                   width=600,
                   height=320,
                   margin=(5, 0, -70, 95))

    line = '''
    <div class="vertical"></div>
    '''
    vertical_line = Div(text=line,
                        width=20,
                        height=320,
                        margin=(80, 0, -70, -10))

    # Let's get the exploratory charts generated.

    malicious_hist, malicious_edge = np.histogram(malicious, bins=100)
    mal_hist_df = pd.DataFrame({
        'metric': malicious_hist,
        'left': malicious_edge[:-1],
        'right': malicious_edge[1:]
    })

    normal_hist, normal_edge = np.histogram(normal, bins=100)
    norm_hist_df = pd.DataFrame({
        'metric': normal_hist,
        'left': normal_edge[:-1],
        'right': normal_edge[1:]
    })

    exploratory = figure(
        plot_width=plot_width,
        plot_height=plot_height,
        sizing_mode='fixed',
        title=f'{metric_col.capitalize()} Distribution (σ = std dev)',
        x_axis_label=f'{metric_col.capitalize()}',
        y_axis_label='Observations')

    exploratory.title.text_font_size = title_font_size
    exploratory.border_fill_color = cell_bg_color
    exploratory.border_fill_alpha = cell_bg_alpha
    exploratory.background_fill_color = cell_bg_color
    exploratory.background_fill_alpha = plot_bg_alpha
    exploratory.min_border_left = left_border
    exploratory.min_border_right = right_border
    exploratory.min_border_top = top_border
    exploratory.min_border_bottom = bottom_border

    exploratory.quad(bottom=0,
                     top=mal_hist_df.metric,
                     left=mal_hist_df.left,
                     right=mal_hist_df.right,
                     legend_label='malicious',
                     fill_color=malicious_color,
                     alpha=.85,
                     line_alpha=.35,
                     line_width=.5)
    exploratory.quad(bottom=0,
                     top=norm_hist_df.metric,
                     left=norm_hist_df.left,
                     right=norm_hist_df.right,
                     legend_label='normal',
                     fill_color=normal_color,
                     alpha=.35,
                     line_alpha=.35,
                     line_width=.5)

    exploratory.add_layout(
        Arrow(end=NormalHead(fill_color=malicious_color, size=10,
                             line_alpha=0),
              line_color=malicious_color,
              x_start=mal_mean,
              y_start=mal_count,
              x_end=mal_mean,
              y_end=0))
    arrow_label = Label(x=mal_mean,
                        y=mal_count,
                        y_offset=5,
                        text='Malicious Events',
                        text_font_style='bold',
                        text_color=malicious_color,
                        text_font_size='10pt')

    exploratory.add_layout(arrow_label)
    exploratory.xaxis.formatter = NumeralTickFormatter(format='0,0')
    exploratory.yaxis.formatter = NumeralTickFormatter(format='0,0')

    # 3 sigma reference line
    sigma_ref(exploratory, norm_mean, norm_stddev)

    exploratory.legend.location = "top_right"
    exploratory.legend.background_fill_alpha = .3

    # Zoomed in version
    overlap_view = figure(
        plot_width=plot_width,
        plot_height=plot_height,
        sizing_mode='fixed',
        title=f'Overlap Highlight',
        x_axis_label=f'{metric_col.capitalize()}',
        y_axis_label='Observations',
        y_range=(0, mal_count * .33),
        x_range=(norm_mean + (norm_stddev * 2.5), mal_mean + (mal_stddev * 3)),
    )

    overlap_view.title.text_font_size = title_font_size
    overlap_view.border_fill_color = cell_bg_color
    overlap_view.border_fill_alpha = cell_bg_alpha
    overlap_view.background_fill_color = cell_bg_color
    overlap_view.background_fill_alpha = plot_bg_alpha
    overlap_view.min_border_left = left_border
    overlap_view.min_border_right = right_border
    overlap_view.min_border_top = top_border
    overlap_view.min_border_bottom = bottom_border

    overlap_view.quad(bottom=0,
                      top=mal_hist_df.metric,
                      left=mal_hist_df.left,
                      right=mal_hist_df.right,
                      legend_label='malicious',
                      fill_color=malicious_color,
                      alpha=.85,
                      line_alpha=.35,
                      line_width=.5)
    overlap_view.quad(bottom=0,
                      top=norm_hist_df.metric,
                      left=norm_hist_df.left,
                      right=norm_hist_df.right,
                      legend_label='normal',
                      fill_color=normal_color,
                      alpha=.35,
                      line_alpha=.35,
                      line_width=.5)
    overlap_view.xaxis.formatter = NumeralTickFormatter(format='0,0')
    overlap_view.yaxis.formatter = NumeralTickFormatter(format='0,0')

    sigma_ref(overlap_view, norm_mean, norm_stddev)

    overlap_view.legend.location = "top_right"
    overlap_view.legend.background_fill_alpha = .3

    # Probability Density - bigger bins for sparser malicous observations
    malicious_hist_dense, malicious_edge_dense = np.histogram(malicious,
                                                              density=True,
                                                              bins=50)
    mal_hist_dense_df = pd.DataFrame({
        'metric': malicious_hist_dense,
        'left': malicious_edge_dense[:-1],
        'right': malicious_edge_dense[1:]
    })

    normal_hist_dense, normal_edge_dense = np.histogram(normal,
                                                        density=True,
                                                        bins=100)
    norm_hist_dense_df = pd.DataFrame({
        'metric': normal_hist_dense,
        'left': normal_edge_dense[:-1],
        'right': normal_edge_dense[1:]
    })

    density = figure(plot_width=plot_width,
                     plot_height=plot_height,
                     sizing_mode='fixed',
                     title='Probability Density',
                     x_axis_label=f'{metric_col.capitalize()}',
                     y_axis_label='% of Group Total')

    density.title.text_font_size = title_font_size
    density.border_fill_color = cell_bg_color
    density.border_fill_alpha = cell_bg_alpha
    density.background_fill_color = cell_bg_color
    density.background_fill_alpha = plot_bg_alpha
    density.min_border_left = left_border
    density.min_border_right = right_border
    density.min_border_top = top_border
    density.min_border_bottom = bottom_border

    density.quad(bottom=0,
                 top=mal_hist_dense_df.metric,
                 left=mal_hist_dense_df.left,
                 right=mal_hist_dense_df.right,
                 legend_label='malicious',
                 fill_color=malicious_color,
                 alpha=.85,
                 line_alpha=.35,
                 line_width=.5)
    density.quad(bottom=0,
                 top=norm_hist_dense_df.metric,
                 left=norm_hist_dense_df.left,
                 right=norm_hist_dense_df.right,
                 legend_label='normal',
                 fill_color=normal_color,
                 alpha=.35,
                 line_alpha=.35,
                 line_width=.5)
    density.xaxis.formatter = NumeralTickFormatter(format='0,0')
    density.yaxis.formatter = NumeralTickFormatter(format='0.000%')

    sigma_ref(density, norm_mean, norm_stddev)

    density.legend.location = "top_right"
    density.legend.background_fill_alpha = .3

    # Simulation Series to be used
    false_positives = simulations.FP
    false_negatives = simulations.FN
    multiplier = simulations.multiplier
    precision = simulations.precision
    recall = simulations.recall
    f1_score = simulations.f1_score
    f1_max = simulations[simulations.f1_score == simulations.f1_score.max(
    )].head(1).squeeze()['multiplier']

    # False Positives vs False Negatives
    errors = figure(plot_width=plot_width,
                    plot_height=plot_height,
                    sizing_mode='fixed',
                    x_range=(multiplier.min(), multiplier.max()),
                    y_range=(0, false_positives.max()),
                    title='False Positives vs False Negatives',
                    x_axis_label='Multiplier',
                    y_axis_label='Count')

    errors.title.text_font_size = title_font_size
    errors.border_fill_color = cell_bg_color
    errors.border_fill_alpha = cell_bg_alpha
    errors.background_fill_color = cell_bg_color
    errors.background_fill_alpha = plot_bg_alpha
    errors.min_border_left = left_border
    errors.min_border_right = right_border
    errors.min_border_top = top_border
    errors.min_border_bottom = right_border

    errors.line(multiplier,
                false_positives,
                legend_label='false positives',
                line_width=2,
                color=fp_color)
    errors.line(multiplier,
                false_negatives,
                legend_label='false negatives',
                line_width=2,
                color=fn_color)
    errors.yaxis.formatter = NumeralTickFormatter(format='0,0')

    errors.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)}
    errors.add_layout(
        LinearAxis(y_range_name="y2",
                   axis_label="Score",
                   formatter=NumeralTickFormatter(format='0.00%')), 'right')
    errors.line(multiplier,
                f1_score,
                line_width=2,
                color=f1_color,
                legend_label='F1 Score',
                y_range_name="y2")

    # F1 Score Maximization point
    f1_thresh = Span(location=f1_max,
                     dimension='height',
                     line_color=f1_color,
                     line_dash='dashed',
                     line_width=2)
    f1_label = Label(x=f1_max + .05,
                     y=180,
                     y_units='screen',
                     text=f'F1 Max: {round(f1_max,2)}',
                     text_font_size='10pt',
                     text_font_style='bold',
                     text_align='left',
                     text_color=f1_color)

    errors.add_layout(f1_thresh)
    errors.add_layout(f1_label)

    errors.legend.location = "top_right"
    errors.legend.background_fill_alpha = .3

    # False Negative Weighting.
    # Intro.
    weighting_intro = f'''
    <h3>Error types differ in impact.</h3> 
    <p>In the case of security incidents, a false negative, 
though possibly rarer than false positives, is likely more costly. For example, downtime suffered 
from a DDoS attack (lost sales/customers) incurs more loss than time wasted chasing a false positive 
(labor hours). </p>

<p>Try playing around with the slider to the right to see how your thresholding strategy might need to change 
depending on the relative weight of false negatives to false positives. What does it look like at
1:1, 50:1, etc.?</p>
'''

    weighting_div = Div(text=weighting_intro,
                        width=420,
                        height=180,
                        margin=(0, 75, 0, 0))

    # Now for the weighted errors viz

    default_weighting = 10
    initial_fp_cost = 100
    simulations['weighted_FN'] = simulations.FN * default_weighting
    weighted_fn = simulations.weighted_FN
    simulations[
        'total_weighted_error'] = simulations.FP + simulations.weighted_FN
    total_weighted_error = simulations.total_weighted_error
    simulations['fp_cost'] = initial_fp_cost
    fp_cost = simulations.fp_cost
    simulations[
        'total_estimated_cost'] = simulations.total_weighted_error * simulations.fp_cost
    total_estimated_cost = simulations.total_estimated_cost
    twe_min = simulations[simulations.total_weighted_error ==
                          simulations.total_weighted_error.min()].head(
                              1).squeeze()['multiplier']
    twe_min_count = simulations[simulations.multiplier == twe_min].head(
        1).squeeze()['total_weighted_error']
    generic_twe = simulations[simulations.multiplier.apply(
        lambda x: round(x, 2)) == 3.00].squeeze()['total_weighted_error']

    comparison = f'''
    <p>Based on your inputs, the optimal threshold is around <b>{twe_min}</b>.
    This would result in an estimated <b>{int(twe_min_count):,}</b> total weighted errors and 
    <b>${int(twe_min_count * initial_fp_cost):,}</b> in losses.</p>

    <p>The generic threshold of 3.0 standard deviations would result in <b>{int(generic_twe):,}</b> 
    total weighted errors and <b>${int(generic_twe * initial_fp_cost):,}</b> in losses.</p>

    <p>Using the optimal threshold would save <b>${int((generic_twe - twe_min_count) * initial_fp_cost):,}</b>, 
    reducing costs by <b>{(generic_twe - twe_min_count) / generic_twe * 100:.1f}%</b> 
    (assuming near-future events are distributed similarly to those from the past).</p>
    '''
    comparison_div = Div(text=comparison,
                         width=420,
                         height=230,
                         margin=(0, 75, 0, 0))

    loss_min = ColumnDataSource(data=dict(multiplier=multiplier,
                                          fp=false_positives,
                                          fn=false_negatives,
                                          weighted_fn=weighted_fn,
                                          twe=total_weighted_error,
                                          fpc=fp_cost,
                                          tec=total_estimated_cost,
                                          precision=precision,
                                          recall=recall,
                                          f1=f1_score))

    evaluation = Figure(plot_width=900,
                        plot_height=520,
                        sizing_mode='fixed',
                        x_range=(multiplier.min(), multiplier.max()),
                        title='Evaluation Metrics vs Total Estimated Cost',
                        x_axis_label='Multiplier',
                        y_axis_label='Cost')

    evaluation.title.text_font_size = title_font_size
    evaluation.border_fill_color = cell_bg_color
    evaluation.border_fill_alpha = cell_bg_alpha
    evaluation.background_fill_color = cell_bg_color
    evaluation.background_fill_alpha = plot_bg_alpha
    evaluation.min_border_left = left_border
    evaluation.min_border_right = right_border
    evaluation.min_border_top = top_border
    evaluation.min_border_bottom = bottom_border

    evaluation.line('multiplier',
                    'tec',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=total_weighted_color,
                    legend_label='Total Estimated Cost')
    evaluation.yaxis.formatter = NumeralTickFormatter(format='$0,0')

    # Evaluation metrics on second right axis.
    evaluation.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)}

    evaluation.add_layout(
        LinearAxis(y_range_name="y2",
                   axis_label="Score",
                   formatter=NumeralTickFormatter(format='0.00%')), 'right')
    evaluation.line('multiplier',
                    'precision',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=precision_color,
                    legend_label='Precision',
                    y_range_name="y2")
    evaluation.line('multiplier',
                    'recall',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=recall_color,
                    legend_label='Recall',
                    y_range_name="y2")
    evaluation.line('multiplier',
                    'f1',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=f1_color,
                    legend_label='F1 score',
                    y_range_name="y2")
    evaluation.legend.location = "bottom_right"
    evaluation.legend.background_fill_alpha = .3

    twe_thresh = Span(location=twe_min,
                      dimension='height',
                      line_color=total_weighted_color,
                      line_dash='dashed',
                      line_width=2)
    twe_label = Label(x=twe_min - .05,
                      y=240,
                      y_units='screen',
                      text=f'Cost Min: {round(twe_min,2)}',
                      text_font_size='10pt',
                      text_font_style='bold',
                      text_align='right',
                      text_color=total_weighted_color)
    evaluation.add_layout(twe_thresh)
    evaluation.add_layout(twe_label)

    # Add in same f1 thresh as previous viz
    evaluation.add_layout(f1_thresh)
    evaluation.add_layout(f1_label)

    handler = CustomJS(args=dict(source=loss_min,
                                 thresh=twe_thresh,
                                 label=twe_label,
                                 comparison=comparison_div),
                       code="""
       var data = source.data
       var ratio = cb_obj.value
       var multiplier = data['multiplier']
       var fp = data['fp']
       var fn = data['fn']
       var weighted_fn = data['weighted_fn']
       var twe = data['twe']
       var fpc = data['fpc']
       var tec = data['tec']
       var generic_twe = 0
       
       function round(value, decimals) {
       return Number(Math.round(value+'e'+decimals)+'e-'+decimals);
       }
       
       function comma_sep(x) {
           return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ",");
       }
       
       for (var i = 0; i < multiplier.length; i++) {
          weighted_fn[i] = Math.round(fn[i] * ratio)
          twe[i] = weighted_fn[i] + fp[i]
          tec[i] = twe[i] * fpc[i]
          if (round(multiplier[i],2) == 3.00) {
            generic_twe = twe[i]
          }
       }
              
       var min_loss = Math.min.apply(null,twe)
       var new_thresh = 0
       
       for (var i = 0; i < multiplier.length; i++) {
       if (twe[i] == min_loss) {
           new_thresh = multiplier[i]
           thresh.location = new_thresh
           thresh.change.emit()
           label.x = new_thresh
           label.text = `Cost Min: ${new_thresh}`
           label.change.emit()
           comparison.text = `
            <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>.
            This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and 
            <b>$${comma_sep(round(min_loss * fpc[i],0))}</b> in losses.</p>
        
            <p>The generic threshold of 3.0 standard deviations would result in <b>${comma_sep(round(generic_twe,0))}</b> 
            total weighted errors and <b>$${comma_sep(round(generic_twe * fpc[i],0))}</b> in losses.</p>
        
            <p>Using the optimal threshold would save <b>$${comma_sep(round((generic_twe - min_loss) * fpc[i],0))}</b>, 
            reducing costs by <b>${comma_sep(round((generic_twe - min_loss) / generic_twe * 100,0))}%</b> 
            (assuming near-future events are distributed similarly to those from the past).</p>
           `
           comparison.change.emit()
         }
       }
       source.change.emit();
    """)

    slider = Slider(start=1.0,
                    end=500,
                    value=default_weighting,
                    step=.25,
                    title="FN:FP Ratio",
                    bar_color='#FFD100',
                    height=50,
                    margin=(5, 0, 5, 0))
    slider.js_on_change('value', handler)

    cost_handler = CustomJS(args=dict(source=loss_min,
                                      comparison=comparison_div),
                            code="""
           var data = source.data
           var new_cost = cb_obj.value
           var multiplier = data['multiplier']
           var fp = data['fp']
           var fn = data['fn']
           var weighted_fn = data['weighted_fn']
           var twe = data['twe']
           var fpc = data['fpc']
           var tec = data['tec']
           var generic_twe = 0
           
           function round(value, decimals) {
           return Number(Math.round(value+'e'+decimals)+'e-'+decimals);
           } 

           function comma_sep(x) {
               return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ",");
           }
           
           for (var i = 0; i < multiplier.length; i++) {
              fpc[i] = new_cost
              tec[i] = twe[i] * fpc[i]
              if (round(multiplier[i],2) == 3.00) {
                generic_twe = twe[i]
              }
           }

           var min_loss = Math.min.apply(null,twe)
           var new_thresh = 0

           for (var i = 0; i < multiplier.length; i++) {
           if (twe[i] == min_loss) {
               new_thresh = multiplier[i]
               comparison.text = `
                <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>.
                This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and 
                <b>$${comma_sep(round(min_loss * new_cost,0))}</b> in losses.</p>

                <p>The generic threshold of 3.0 standard deviations would result in 
                <b>${comma_sep(round(generic_twe,0))}</b> total weighted errors and 
                <b>$${comma_sep(round(generic_twe * new_cost,0))}</b> in losses.</p>

                <p>Using the optimal threshold would save 
                <b>$${comma_sep(round((generic_twe - min_loss) * new_cost,0))}</b>, 
                reducing costs by <b>${comma_sep(round((generic_twe - min_loss)/generic_twe * 100,0))}%</b> 
                (assuming near-future events are distributed similarly to those from the past).</p>
               `
               comparison.change.emit()
              }
           }
           source.change.emit();
        """)

    cost_input = TextInput(value=f"{initial_fp_cost}",
                           title="How much a false positive costs:",
                           height=75,
                           margin=(20, 75, 20, 0))
    cost_input.js_on_change('value', cost_handler)

    # Include DataTable of simulation results
    dt_columns = [
        TableColumn(field="multiplier", title="Multiplier"),
        TableColumn(field="fp",
                    title="False Positives",
                    formatter=NumberFormatter(format='0,0')),
        TableColumn(field="fn",
                    title="False Negatives",
                    formatter=NumberFormatter(format='0,0')),
        TableColumn(field="weighted_fn",
                    title="Weighted False Negatives",
                    formatter=NumberFormatter(format='0,0.00')),
        TableColumn(field="twe",
                    title="Total Weighted Errors",
                    formatter=NumberFormatter(format='0,0.00')),
        TableColumn(field="fpc",
                    title="Estimated FP Cost",
                    formatter=NumberFormatter(format='$0,0.00')),
        TableColumn(field="tec",
                    title="Estimated Total Cost",
                    formatter=NumberFormatter(format='$0,0.00')),
        TableColumn(field="precision",
                    title="Precision",
                    formatter=NumberFormatter(format='0.00%')),
        TableColumn(field="recall",
                    title="Recall",
                    formatter=NumberFormatter(format='0.00%')),
        TableColumn(field="f1",
                    title="F1 Score",
                    formatter=NumberFormatter(format='0.00%')),
    ]

    data_table = DataTable(source=loss_min,
                           columns=dt_columns,
                           width=1400,
                           height=700,
                           sizing_mode='fixed',
                           fit_columns=True,
                           reorderable=True,
                           sortable=True,
                           margin=(30, 0, 20, 0))

    # weighting_layout = column([weighting_div, evaluation, slider, data_table])
    weighting_layout = column(
        row(column(weighting_div, cost_input, comparison_div),
            column(slider, evaluation), Div(text='', height=200, width=60)),
        data_table)

    # Initialize visualizations in browser
    time.sleep(1.5)

    layout = grid([
        [title_div],
        [row(summary_div, vertical_line, hypo_div)],
        [
            row(Div(text='', height=200, width=60), exploratory,
                Div(text='', height=200, width=10), overlap_view,
                Div(text='', height=200, width=40))
        ],
        [Div(text='', height=10, width=200)],
        [
            row(Div(text='', height=200, width=60), density,
                Div(text='', height=200, width=10), errors,
                Div(text='', height=200, width=40))
        ],
        [Div(text='', height=10, width=200)],
        [
            row(Div(text='', height=200, width=60), weighting_layout,
                Div(text='', height=200, width=40))
        ],
    ])

    # Generate html resources for dashboard
    fonts = os.path.join(os.getcwd(), 'fonts')
    if os.path.isdir(os.path.join(session_folder, 'fonts')):
        shutil.rmtree(os.path.join(session_folder, 'fonts'))
        shutil.copytree(fonts, os.path.join(session_folder, 'fonts'))
    else:
        shutil.copytree(fonts, os.path.join(session_folder, 'fonts'))

    html = file_html(layout, INLINE, "Cream")
    with open(os.path.join(session_folder, f'{session_name}.html'),
              "w") as file:
        file.write(html)
    webbrowser.open("file://" +
                    os.path.join(session_folder, f'{session_name}.html'))
Exemple #13
0
def bokeh_xy_sliders(outputMatrix):
    """Show a scatterplot of lottery score vs. competitiveness per group.

    One circle is drawn per group. Two sliders filter the visible points: a
    point stays visible only while its competitiveness p-value is below the
    first slider's value AND its lottery p-value is below the second
    slider's value. Clicking a legend entry hides that group's point.

    Args:
        outputMatrix: object exposing the parallel per-group sequences
            ``uniqueGroups``, ``lotScores``, ``compScores``, ``compPValues``
            and ``lotPValues``.

    Returns:
        None. The page is written to ``scatterplot.html`` and shown. If there
        are more groups than the viridis palette can color, a message is
        printed and nothing is plotted.
    """
    output_file("scatterplot.html")

    n_groups = len(outputMatrix.uniqueGroups)
    try:
        colors = viridis(n_groups)
    except ValueError:
        # Without a palette no point can be colored; stop here instead of
        # failing later with a NameError on the undefined `colors`.
        print("Error: Bokeh can only display 256 colors. You requested ",
              n_groups,
              " Please try chart style 2 instead.")
        return

    TOOLS = ("hover,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset, save, tap")

    TOOLTIPS = [  # Displays on hover
        ("group", "@groups"), ("Comp P-value", "@compP{0.0000}"),
        ("Lottery P-value", "@lotP{0.0000}"), ("(x,y)", "(@x, @y)")
    ]

    p = figure(tools=TOOLS,
               title="Scatterplot",
               x_axis_label='Lottery Score',
               y_axis_label='Competitiveness',
               tooltips=TOOLTIPS,
               output_backend="webgl",
               toolbar_location="above")

    slider = Slider(start=0.,
                    end=1.,
                    value=1.,
                    step=.01,
                    title="Comp P-value Filter")
    slider2 = Slider(start=0.,
                     end=1.,
                     value=1.,
                     step=.01,
                     title="Lottery P-value Filter")

    def make_pvalue_filter(pslider, source, column):
        # Keep a point only while its p-value in `column` is strictly below
        # the slider's current value. Locals are declared with `var`; implicit
        # globals throw in strict mode and would be shared between filters.
        return CustomJSFilter(args=dict(slider=pslider, source=source),
                              code="""
            var pvals = source.data['%s'];
            var bools = [];
            for (var i = 0; i < source.data.x.length; i++) {
                bools.push(pvals[i] < slider.value);
            }
            return bools;
        """ % column)

    points = zip(outputMatrix.uniqueGroups, outputMatrix.lotScores,
                 outputMatrix.compScores, outputMatrix.compPValues,
                 outputMatrix.lotPValues, colors)
    for group, lot_score, comp_score, comp_p, lot_p, color in points:
        # One single-row source per group, so each group is its own renderer
        # (and legend entry) that the legend click policy can hide.
        source = ColumnDataSource(data=dict(x=[lot_score],
                                            y=[comp_score],
                                            groups=[group],
                                            compP=[comp_p],
                                            lotP=[lot_p],
                                            size=[15],
                                            color=[color]))
        # Emit a change on every slider move, presumably so the CDSView
        # filters below get re-evaluated.
        callback = CustomJS(args=dict(source=source),
                            code="""
            source.change.emit();
        """)
        slider.js_on_change('value', callback)
        slider2.js_on_change('value', callback)

        view = CDSView(source=source,
                       filters=[make_pvalue_filter(slider, source, 'compP'),
                                make_pvalue_filter(slider2, source, 'lotP')])
        # NOTE(review): `legend=` is deprecated since Bokeh 1.4 in favor of
        # `legend_field=`; kept because this snippet's Bokeh version is unknown.
        p.circle(x="x",
                 y="y",
                 source=source,
                 size="size",
                 legend="groups",
                 color="color",
                 view=view)

    p.legend.click_policy = "hide"  # Click legend entry to hide that point

    layout = row(p, column(slider, slider2))

    show(layout)
Exemple #14
0
def plot_linked_heatmap_weights(new_df,
                                x_key=None,
                                y_key=None,
                                save_path=None,
                                add_t_slider=True,
                                RF_plot_title='Corresponding spatial RFs'):
    """Show a heatmap linked to an image panel of spatial receptive fields.

    Hovering a heatmap cell loads that row's ``url`` into the right-hand image
    panel. With ``add_t_slider`` a vertical "Time step" slider rewrites the
    time step in the displayed URL. The time step is the single character
    right after the URL's first '='.

    Args:
        new_df: DataFrame providing at least the columns ``x``, ``y``,
            ``color``, ``z_value`` and ``url`` (all read below).
        x_key: optional column used to order the x-axis categories
            (ascending). If omitted, they are ordered by ``x`` descending.
        y_key: optional column used to order the y-axis categories
            (descending). If omitted, they are ordered by ``y`` ascending.
        save_path: if given, the file passed to ``output_file``.
        add_t_slider: whether to add the time-step slider.
        RF_plot_title: title of the image panel.

    Returns:
        None. The layout is displayed with ``show``.
    """
    # Order each axis independently so that supplying only one of
    # x_key / y_key works.
    if x_key is None:
        x_sorted = new_df.sort_values('x', ascending=False)
    else:
        x_sorted = new_df.sort_values(x_key, ascending=True)
    if y_key is None:
        y_sorted = new_df.sort_values('y', ascending=True)
    else:
        y_sorted = new_df.sort_values(y_key, ascending=False)
    x_tick_labels = x_sorted['x'].unique().tolist()
    y_tick_labels = y_sorted['y'].unique().tolist()

    # Heatmap: one colored rect per row of new_df.
    s1 = ColumnDataSource(data=new_df)
    p1 = figure(title="Heatmap",
                tools="hover",
                toolbar_location='above',
                x_range=x_tick_labels,
                y_range=y_tick_labels,
                x_axis_label="Log_10 of L1 regularisation on weights",
                y_axis_label='Number of hidden units',
                width=600,
                height=600)
    p1.rect('x', 'y', color='color', source=s1, width=1, height=1)

    # Colorbar spanning the z_value range, built from a matplotlib colormap.
    # NOTE(review): cm.get_cmap was removed in Matplotlib 3.9; newer versions
    # need matplotlib.colormaps["jet"] instead.
    colormap = cm.get_cmap("jet")
    bokehpalette = [
        matplotlib.colors.rgb2hex(m) for m in colormap(np.arange(colormap.N))
    ]
    color_mapper = LinearColorMapper(palette=bokehpalette,
                                     low=np.min(new_df.z_value.values),
                                     high=np.max(new_df.z_value.values))

    color_bar = ColorBar(color_mapper=color_mapper, location=(1, 1))
    c_bar_plot = figure(title="Final validation loss",
                        title_location="right",
                        height=611,
                        width=140,
                        toolbar_location=None,
                        min_border=0,
                        outline_line_color=None)
    c_bar_plot.add_layout(color_bar, 'right')
    c_bar_plot.title.align = "center"

    # s2 holds the single image currently shown; s3 is the full table the
    # hover callback looks URLs up in.
    s2 = ColumnDataSource(data=dict(url=[], x=[], y=[], dw=[], dh=[]))
    s3 = ColumnDataSource(data=new_df)

    p2 = figure(title=RF_plot_title,
                tools=["zoom_in", "zoom_out", "pan", "reset"],
                x_range=[str(i) for i in range(10)],
                y_range=[str(i) for i in range(10)],
                width=600,
                height=600,
                min_border_left=40)

    # Hide the image panel's axis ticks and tick labels.
    p2.xaxis.major_tick_line_color = None
    p2.yaxis.major_tick_line_color = None
    p2.image_url(url='url',
                 x=0,
                 y=0,
                 w='dw',
                 h='dh',
                 source=s2,
                 anchor="bottom_left")

    p1.title.text_font_size = '20pt'
    p1.xaxis.axis_label_text_font_size = "20pt"
    p1.xaxis.major_label_text_font_size = "12pt"
    p1.yaxis.major_label_text_font_size = "12pt"
    p1.yaxis.axis_label_text_font_size = "20pt"
    p2.yaxis.major_label_text_font_size = "0pt"
    p2.xaxis.major_label_text_font_size = "0pt"
    p2.title.text_font_size = '20pt'
    c_bar_plot.title.text_font_size = '16pt'

    # JS helper shared by both callbacks. It replaces the time step, the
    # single character right after the first '=', and keeps the rest of the
    # URL intact. URLs without '=' are returned unchanged.
    set_time_step_js = """
    function set_time_step(url, step) {
        var eq = url.indexOf("=");
        if (eq < 0) {
            return url;
        }
        return url.slice(0, eq + 1) + step + url.slice(eq + 2);
    }
    """

    if add_t_slider:
        slider = Slider(start=1,
                        end=7,
                        value=7,
                        step=1,
                        title="Time step",
                        callback_policy='continuous',
                        orientation='vertical',
                        height=200,
                        bar_color='grey')

        slider_callback_code = set_time_step_js + """
            // Only react to whole-number slider values.
            var value = cb_obj.value;
            if (value === parseInt(value, 10)) {
                var urls = s2.data.url;
                // s2 stays empty until a heatmap cell has been hovered.
                if (urls.length > 0) {
                    s2.data['url'] = [set_time_step(urls[0], value)];
                    s2.change.emit();
                }
            }
            """
        slider.js_on_change(
            'value',
            CustomJS(args=dict(s2=s2, s3=s3), code=slider_callback_code))
    else:
        slider = None

    hover = p1.select_one(HoverTool)
    hover.tooltips = None
    hover_callback_code = set_time_step_js + """
    // Newer Bokeh passes a Selection exposing `indices`; older releases nest
    // them under '1d'.
    var sel = cb_data.index;
    var indices = (sel.indices !== undefined) ? sel.indices : sel['1d'].indices;
    if (indices.length > 0) {
        var url = s3.data.url[indices[0]];

        // Show the hovered row's image at the slider's current time step.
        if (slider != null && url) {
            var value = slider.value;
            if (value === parseInt(value, 10)) {
                url = set_time_step(url, value);
            }
        }

        var d2 = s2.data;
        d2['url'] = [url];
        d2['x'] = [0];
        d2['y'] = [0];
        d2['dw'] = [10];
        d2['dh'] = [10];
        s2.change.emit();
    }
    """

    hover.callback = CustomJS(args=dict(s2=s2, s3=s3, slider=slider),
                              code=hover_callback_code)

    children = [p1, c_bar_plot, p2]
    if slider is not None:
        children.append(slider)
    layout = row(*children)

    if save_path is not None:
        output_file(save_path)
    show(layout)
Exemple #15
0
class ViewerWidgets(object):
    """Encapsulate Bokeh widgets, and related callbacks, that are part of prospect's GUI.

    VI widgets are not created here; they are received as ``vi_widgets`` in
    :meth:`add_update_plot_callback`. Widgets are built by the ``add_*`` methods,
    stored as attributes, and wired to ``CustomJS`` callbacks (browser-side code).
    """

    def __init__(self, plots, nspec):
        """Create the spectrum-selection and smoothing sliders.

        Parameters
        ----------
        plots : object
            Plots container; its ``plot_width`` and ``plot_height`` are used to
            scale the widgets.
        nspec : int
            Number of spectra that can be selected.
        """
        # Mapping of JS resource file name -> code (used by concatenating entries
        # such as self.js_files["adapt_plotrange.js"] below).
        self.js_files = get_resources('js')
        self.navigation_button_width = 30
        self.z_button_width = 30
        self.plot_widget_width = (plots.plot_width+(plots.plot_height//2))//2 - 40 # used for widgets scaling

        #-----
        #- Ifiberslider and smoothing widgets
        # Ifiberslider's value controls which spectrum is displayed
        # These two widgets call update_plot(), later defined
        slider_end = nspec-1 if nspec > 1 else 0.5 # Slider cannot have start=end
        self.ifiberslider = Slider(start=0, end=slider_end, value=0, step=1, title='Spectrum (of '+str(nspec)+')')
        self.smootherslider = Slider(start=0, end=26, value=0, step=1.0, title='Gaussian Sigma Smooth')
        self.coaddcam_buttons = None
        self.model_select = None


    def add_navigation(self, nspec):
        """Add "<" / ">" buttons that step ``ifiberslider`` to the previous / next spectrum.

        Parameters
        ----------
        nspec : int
            Number of spectra; the ">" button does not go beyond index ``nspec-1``.
        """
        #-----
        #- Navigation buttons
        self.prev_button = Button(label="<", width=self.navigation_button_width)
        self.next_button = Button(label=">", width=self.navigation_button_width)
        self.prev_callback = CustomJS(
            args=dict(ifiberslider=self.ifiberslider),
            code="""
            if(ifiberslider.value>0 && ifiberslider.end>=1) {
                ifiberslider.value--
            }
            """)
        self.next_callback = CustomJS(
            args=dict(ifiberslider=self.ifiberslider, nspec=nspec),
            code="""
            if(ifiberslider.value<nspec-1 && ifiberslider.end>=1) {
                ifiberslider.value++
            }
            """)
        self.prev_button.js_on_event('button_click', self.prev_callback)
        self.next_button.js_on_event('button_click', self.next_callback)

    def add_resetrange(self, viewer_cds, plots):
        """Add a button resetting the X-Y range of ``plots.fig``.

        The JS code is the concatenation of ``adapt_plotrange.js`` and
        ``reset_plotrange.js``; it receives the figure, ``plots.xmin``/``plots.xmax``
        and the spectra CDS.
        """
        #-----
        #- Axis reset button (supersedes the default bokeh "reset")
        self.reset_plotrange_button = Button(label="Reset X-Y range", button_type="default")
        reset_plotrange_code = self.js_files["adapt_plotrange.js"] + self.js_files["reset_plotrange.js"]
        self.reset_plotrange_callback = CustomJS(args = dict(fig=plots.fig, xmin=plots.xmin, xmax=plots.xmax, spectra=viewer_cds.cds_spectra),
                                            code = reset_plotrange_code)
        self.reset_plotrange_button.js_on_event('button_click', self.reset_plotrange_callback)


    def add_redshift_widgets(self, z, viewer_cds, plots):
        """Add the redshift widgets and their JS callbacks.

        Widgets: a coarse slider (step 0.01) and a fine slider (step 0.0001),
        a text input holding the redshift, -/+0.01 buttons, a "Reset to z_pipe"
        button and an Obs/Rest wavelength-frame selector. ``z_input`` is the
        single source of truth: every other widget writes to it, and its
        callback shifts spectral lines, overlap bands, spectra and models.

        Parameters
        ----------
        z : float
            Initial redshift.
        viewer_cds : object
            Holder of the ColumnDataSources (spectra, models, metadata, lines).
        plots : object
            Holder of the figure and line/band annotations.
        """
        ## TODO handle "z" (same issue as viewerplots TBD)

        #-----
        #- Redshift / wavelength scale widgets
        z1 = np.floor(z*100)/100
        dz = z-z1
        self.zslider = Slider(start=-0.1, end=5.0, value=z1, step=0.01, title='Redshift rough tuning')
        self.dzslider = Slider(start=0.0, end=0.0099, value=dz, step=0.0001, title='Redshift fine-tuning')
        self.dzslider.format = "0[.]0000"
        self.z_input = TextInput(value="{:.4f}".format(z), title="Redshift value:")

        #- Observer vs. Rest frame wavelengths
        self.waveframe_buttons = RadioButtonGroup(
            labels=["Obs", "Rest"], active=0)

        self.zslider_callback  = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            // Protect against 1) recursive call with z_input callback;
            //   2) out-of-range zslider values (should never happen in principle)
            var z1 = Math.floor(parseFloat(z_input.value)*100) / 100
            if ( (Math.abs(zslider.value-z1) >= 0.01) &&
                 (zslider.value >= -0.1) && (zslider.value <= 5.0) ){
                 var new_z = zslider.value + dzslider.value
                 z_input.value = new_z.toFixed(4)
                }
            """)

        # The coarse part z1 must be computed exactly as in the zslider and
        # z_input callbacks (floor to 0.01), otherwise z2 is wrong and the
        # recursion guard below never holds.
        self.dzslider_callback  = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            var z1 = Math.floor(z*100) / 100
            var z2 = z-z1
            if ( (Math.abs(dzslider.value-z2) >= 0.0001) &&
                 (dzslider.value >= 0.0) && (dzslider.value <= 0.0099) ){
                 var new_z = zslider.value + dzslider.value
                 z_input.value = new_z.toFixed(4)
                }
            """)

        self.zslider.js_on_change('value', self.zslider_callback)
        self.dzslider.js_on_change('value', self.dzslider_callback)

        self.z_minus_button = Button(label="<", width=self.z_button_width)
        self.z_plus_button = Button(label=">", width=self.z_button_width)
        self.z_minus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z >= -0.09) {
                z -= 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_plus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z <= 4.99) {
                z += 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_minus_button.js_on_event('button_click', self.z_minus_callback)
        self.z_plus_button.js_on_event('button_click', self.z_plus_callback)

        self.zreset_button = Button(label='Reset to z_pipe')
        self.zreset_callback = CustomJS(
            args=dict(z_input=self.z_input, metadata=viewer_cds.cds_metadata, ifiberslider=self.ifiberslider),
            code="""
                var ifiber = ifiberslider.value
                var z = metadata.data['Z'][ifiber]
                z_input.value = z.toFixed(4)
            """)
        self.zreset_button.js_on_event('button_click', self.zreset_callback)

        # Main redshift callback. Setting z_input.value re-triggers this callback
        # (it is attached to z_input's "value" below), hence the guards.
        self.z_input_callback = CustomJS(
            args=dict(spectra = viewer_cds.cds_spectra,
                coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                model = viewer_cds.cds_model,
                othermodel = viewer_cds.cds_othermodel,
                metadata = viewer_cds.cds_metadata,
                ifiberslider = self.ifiberslider,
                zslider = self.zslider,
                dzslider = self.dzslider,
                z_input = self.z_input,
                waveframe_buttons = self.waveframe_buttons,
                line_data = viewer_cds.cds_spectral_lines,
                lines = plots.speclines,
                line_labels = plots.specline_labels,
                zlines = plots.zoom_speclines,
                zline_labels = plots.zoom_specline_labels,
                overlap_waves = plots.overlap_waves,
                overlap_bands = plots.overlap_bands,
                fig = plots.fig
                ),
            code="""
                var z = parseFloat(z_input.value)
                if ( z >=-0.1 && z <= 5.0 ) {
                    // update zsliders only if needed (avoid recursive call)
                    z_input.value = parseFloat(z_input.value).toFixed(4)
                    var z1 = Math.floor(z*100) / 100
                    var z2 = z-z1
                    if ( Math.abs(z1-zslider.value) >= 0.01) zslider.value = parseFloat(parseFloat(z1).toFixed(2))
                    if ( Math.abs(z2-dzslider.value) >= 0.0001) dzslider.value = parseFloat(parseFloat(z2).toFixed(4))
                } else {
                    // Out-of-range (or non-numeric) value: clamping z_input re-triggers
                    // this callback with a valid redshift; stop here so that lines,
                    // bands and spectra are never shifted with an invalid z.
                    if (z_input.value < -0.1) z_input.value = (-0.1).toFixed(4)
                    if (z_input.value > 5) z_input.value = (5.0).toFixed(4)
                    return
                }

                var line_restwave = line_data.data['restwave']
                var ifiber = ifiberslider.value
                var waveshift_lines = (waveframe_buttons.active == 0) ? 1+z : 1 ;
                var waveshift_spec = (waveframe_buttons.active == 0) ? 1 : 1/(1+z) ;

                for(var i=0; i<line_restwave.length; i++) {
                    lines[i].location = line_restwave[i] * waveshift_lines
                    line_labels[i].x = line_restwave[i] * waveshift_lines
                    zlines[i].location = line_restwave[i] * waveshift_lines
                    zline_labels[i].x = line_restwave[i] * waveshift_lines
                }
                if (overlap_bands.length>0) {
                    for (var i=0; i<overlap_bands.length; i++) {
                        overlap_bands[i].left = overlap_waves[i][0] * waveshift_spec
                        overlap_bands[i].right = overlap_waves[i][1] * waveshift_spec
                    }
                }

                function shift_plotwave(cds_spec, waveshift) {
                    var data = cds_spec.data
                    var origwave = data['origwave']
                    var plotwave = data['plotwave']
                    if ( plotwave[0] != origwave[0] * waveshift ) { // Avoid redo calculation if not needed
                        for (var j=0; j<plotwave.length; j++) {
                            plotwave[j] = origwave[j] * waveshift ;
                        }
                        cds_spec.change.emit()
                    }
                }

                for(var i=0; i<spectra.length; i++) {
                    shift_plotwave(spectra[i], waveshift_spec)
                }
                if (coaddcam_spec) shift_plotwave(coaddcam_spec, waveshift_spec)

                // Update model wavelength array
                // NEW : don't shift model if othermodel is there
                if (othermodel) {
                    var zref = othermodel.data['zref'][0]
                    var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zref) : 1/(1+zref) ;
                    shift_plotwave(othermodel, waveshift_model)
                } else if (model) {
                    var zfit = 0.0
                    if(metadata.data['Z'] !== undefined) {
                        zfit = metadata.data['Z'][ifiber]
                    }
                    var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zfit) : 1/(1+zfit) ;
                    shift_plotwave(model, waveshift_model)
                }
            """)
        self.z_input.js_on_change('value', self.z_input_callback)
        self.waveframe_buttons.js_on_click(self.z_input_callback)

        # Rescale the current x-range when switching between observed and rest frame.
        self.plotrange_callback = CustomJS(
            args = dict(
                z_input=self.z_input,
                waveframe_buttons=self.waveframe_buttons,
                fig=plots.fig,
            ),
            code="""
            var z =  parseFloat(z_input.value)
            // Observer Frame
            if(waveframe_buttons.active == 0) {
                fig.x_range.start = fig.x_range.start * (1+z)
                fig.x_range.end = fig.x_range.end * (1+z)
            } else {
                fig.x_range.start = fig.x_range.start / (1+z)
                fig.x_range.end = fig.x_range.end / (1+z)
            }
            """
        )
        self.waveframe_buttons.js_on_click(self.plotrange_callback) # TODO: for record: is this related to waveframe bug? : 2 callbacks for same click...


    def add_oii_widgets(self, plots):
        """Add "OII-zoom" and "Undo OII-zoom" buttons.

        The zoom button saves the current x-range and smoothing into a one-row
        CDS, centres the x-range on 3728.48*(1+z) +/- 100 and sets smoothing to 0.
        The undo button restores the saved values.
        """
        #------
        #- Zoom on the OII doublet TODO mv js code to other file
        # TODO: is there another trick than using a cds to pass the "oii_saveinfo" ?
        # TODO: optimize smoothing for autozoom (current value: 0)
        cds_oii_saveinfo = ColumnDataSource(
            {'xmin':[plots.fig.x_range.start], 'xmax':[plots.fig.x_range.end], 'nsmooth':[self.smootherslider.value]})
        self.oii_zoom_button = Button(label="OII-zoom", button_type="default")
        self.oii_zoom_callback = CustomJS(
            args = dict(z_input=self.z_input, fig=plots.fig, smootherslider=self.smootherslider,
                       cds_oii_saveinfo=cds_oii_saveinfo),
            code = """
            // Save previous setting (for the "Undo" button)
            cds_oii_saveinfo.data['xmin'] = [fig.x_range.start]
            cds_oii_saveinfo.data['xmax'] = [fig.x_range.end]
            cds_oii_saveinfo.data['nsmooth'] = [smootherslider.value]
            // Center on the middle of the redshifted OII doublet (vacuum)
            var z = parseFloat(z_input.value)
            fig.x_range.start = 3728.48 * (1+z) - 100
            fig.x_range.end = 3728.48 * (1+z) + 100
            // No smoothing (this implies a call to update_plot)
            smootherslider.value = 0
            """)
        self.oii_zoom_button.js_on_event('button_click', self.oii_zoom_callback)

        self.oii_undo_button = Button(label="Undo OII-zoom", button_type="default")
        self.oii_undo_callback = CustomJS(
            args = dict(fig=plots.fig, smootherslider=self.smootherslider, cds_oii_saveinfo=cds_oii_saveinfo),
            code = """
            fig.x_range.start = cds_oii_saveinfo.data['xmin'][0]
            fig.x_range.end = cds_oii_saveinfo.data['xmax'][0]
            smootherslider.value = cds_oii_saveinfo.data['nsmooth'][0]
            """)
        self.oii_undo_button.js_on_event('button_click', self.oii_undo_callback)


    def add_coaddcam(self, plots):
        """Add a selector highlighting either camera-coadded or single-arm spectra.

        The callback sets ``line_alpha`` of the data/noise lines: the last line
        in each list is treated as the coadded one, and is dimmed to
        ``plots.alpha_discrete`` in "Single-arm" mode. The other lines are dimmed
        in "Camera-coadded" mode. Overlap bands are only shown (with
        ``plots.alpha_overlapband``) in "Camera-coadded" mode.
        """
        #-----
        #- Highlight individual-arm or camera-coadded spectra
        coaddcam_labels = ["Camera-coadded", "Single-arm"]
        self.coaddcam_buttons = RadioButtonGroup(labels=coaddcam_labels, active=0)
        self.coaddcam_callback = CustomJS(
            args = dict(coaddcam_buttons = self.coaddcam_buttons,
                        list_lines=[plots.data_lines, plots.noise_lines,
                                    plots.zoom_data_lines, plots.zoom_noise_lines],
                        alpha_discrete = plots.alpha_discrete,
                        overlap_bands = plots.overlap_bands,
                        alpha_overlapband = plots.alpha_overlapband),
            code="""
            var n_lines = list_lines[0].length
            for (var i=0; i<n_lines; i++) {
                var new_alpha = 1
                if (coaddcam_buttons.active == 0 && i<n_lines-1) new_alpha = alpha_discrete
                if (coaddcam_buttons.active == 1 && i==n_lines-1) new_alpha = alpha_discrete
                for (var j=0; j<list_lines.length; j++) {
                    list_lines[j][i].glyph.line_alpha = new_alpha
                }
            }
            var new_alpha = 0
            if (coaddcam_buttons.active == 0) new_alpha = alpha_overlapband
            for (var j=0; j<overlap_bands.length; j++) {
                    overlap_bands[j].fill_alpha = new_alpha
            }
            """
        )
        self.coaddcam_buttons.js_on_click(self.coaddcam_callback)


    @staticmethod
    def _select_metadata_keys(available_keys, phot_bands, top_metadata):
        """Select the metadata keys to display, and which of them are "top" keys.

        Parameters
        ----------
        available_keys : container of str
            Keys present in the metadata CDS.
        phot_bands : iterable of str
            Photometric bands; keys ``'mag_'+band`` are considered.
        top_metadata : iterable of str
            Keys to highlight. It is NOT modified.

        Returns
        -------
        (table_keys, top_keys) : (list of str, list of str)
            ``table_keys`` lists the available keys in display order; for each
            candidate ``K`` that has a ``NUM_K`` column, ``FIRST_K``, ``LAST_K``
            and ``NUM_K`` are listed. ``top_keys`` is a new list made of
            ``top_metadata`` plus those FIRST_/LAST_/NUM_ variants of top keys.
        """
        metadata_to_check = ['TARGETID', 'HPXPIXEL', 'TILEID', 'COADD_NUMEXP', 'COADD_EXPTIME',
                             'NIGHT', 'EXPID', 'FIBER', 'CAMERA', 'MORPHTYPE']
        metadata_to_check += [ ('mag_'+x) for x in phot_bands ]
        # Work on a copy: never mutate the caller's list (or a shared default).
        top_keys = list(top_metadata)
        table_keys = []
        for key in metadata_to_check:
            if key in available_keys:
                table_keys.append(key)
            if 'NUM_'+key in available_keys:
                for prefix in ['FIRST','LAST','NUM']:
                    table_keys.append(prefix+'_'+key)
                    if key in top_keys:
                        top_keys.append(prefix+'_'+key)
        return table_keys, top_keys


    def add_metadata_tables(self, viewer_cds, show_zcat=True, template_dicts=None,
                           top_metadata=('TARGETID', 'EXPID')):
        """ Display object-related information
                top_metadata: metadata to be highlighted in table_a (not modified)

            Note: "short" CDS, with a single row, are used to fill these bokeh tables.
            When changing object, js code modifies these short CDS so that tables are updated.

            Sets attributes shortcds_table_{a,b,c,d,z} and table_{a,b,c,d,z};
            shortcds_table_d/table_d are None when there are at most 7 remaining keys.
        """

        #- Sorted list of potential metadata, and the "top" ones among them:
        table_keys, top_keys = self._select_metadata_keys(
            viewer_cds.cds_metadata.data.keys(), viewer_cds.phot_bands, top_metadata)

        #- Table a: "top metadata"
        table_a_keys = [ x for x in table_keys if x in top_keys ]
        self.shortcds_table_a, self.table_a = _metadata_table(table_a_keys, viewer_cds, table_width=600,
                                                              shortcds_name='shortcds_table_a', selectable=True)
        #- Table b: Targeting information
        self.shortcds_table_b, self.table_b = _metadata_table(['Targeting masks'], viewer_cds, table_width=self.plot_widget_width,
                                                              shortcds_name='shortcds_table_b', selectable=True)
        #- Table(s) c/d : Other information (imaging, etc.)
        remaining_keys = [ x for x in table_keys if x not in top_keys ]
        if len(remaining_keys) > 7:
            table_c_keys = remaining_keys[0:len(remaining_keys)//2]
            table_d_keys = remaining_keys[len(remaining_keys)//2:]
        else:
            table_c_keys = remaining_keys
            table_d_keys = None
        self.shortcds_table_c, self.table_c = _metadata_table(table_c_keys, viewer_cds, table_width=self.plot_widget_width,
                                                             shortcds_name='shortcds_table_c', selectable=False)
        if table_d_keys is None:
            self.shortcds_table_d, self.table_d = None, None
        else:
            self.shortcds_table_d, self.table_d = _metadata_table(table_d_keys, viewer_cds, table_width=self.plot_widget_width,
                                                                 shortcds_name='shortcds_table_d', selectable=False)

        #- Table z: redshift fitting information
        # NOTE(review): only show_zcat=None selects the "Not available" Div below;
        # show_zcat=False still builds the zcat table. Confirm this is intended.
        if show_zcat is not None :
            if template_dicts is not None : # Add other best fits
                fit_results = template_dicts[1]
                # Case of DeltaChi2 : compute it from Chi2s
                #    The "DeltaChi2" in rr fits is between best fits for a given (spectype,subtype)
                #    Convention: DeltaChi2 = -1 for the last fit.
                chi2s = fit_results['CHI2'][0]
                full_deltachi2s = np.zeros(len(chi2s))-1
                full_deltachi2s[:-1] = chi2s[1:]-chi2s[:-1]
                cdsdata = dict(Nfit = np.arange(1,len(chi2s)+1),
                                SPECTYPE = fit_results['SPECTYPE'][0],  # [0:num_best_fits] (if we want to restrict... TODO?)
                                SUBTYPE = fit_results['SUBTYPE'][0],
                                Z = [ "{:.4f}".format(x) for x in fit_results['Z'][0] ],
                                ZERR = [ "{:.4f}".format(x) for x in fit_results['ZERR'][0] ],
                                ZWARN = fit_results['ZWARN'][0],
                                CHI2 = [ "{:.1f}".format(x) for x in fit_results['CHI2'][0] ],
                                DELTACHI2 = [ "{:.1f}".format(x) for x in full_deltachi2s ])
                self.shortcds_table_z = ColumnDataSource(cdsdata, name='shortcds_table_z')
                columns_table_z = [ TableColumn(field=x, title=t, width=w) for x,t,w in [ ('Nfit','Nfit',5), ('SPECTYPE','SPECTYPE',70), ('SUBTYPE','SUBTYPE',60), ('Z','Z',50) , ('ZERR','ZERR',50), ('ZWARN','ZWARN',50), ('DELTACHI2','Δχ2(N+1/N)',70)] ]
                self.table_z = DataTable(source=self.shortcds_table_z, columns=columns_table_z,
                                         selectable=False, index_position=None, width=self.plot_widget_width)
                self.table_z.height = 3 * self.table_z.row_height
            else :
                self.shortcds_table_z, self.table_z = _metadata_table(viewer_cds.zcat_keys, viewer_cds,
                                    table_width=self.plot_widget_width, shortcds_name='shortcds_table_z', selectable=False)
        else :
            self.table_z = Div(text="Not available ")
            self.shortcds_table_z = None


    def add_specline_toggles(self, viewer_cds, plots):
        """Add toggles showing emission / absorption spectral lines.

        A checkbox restricts the display to lines flagged ``major`` in the
        spectral-lines CDS. The visibility of lines and labels is updated in
        both the main and the zoom plots.
        """
        #-----
        #- Toggle lines
        self.speclines_button_group = CheckboxButtonGroup(
                labels=["Emission lines", "Absorption lines"], active=[])
        self.majorline_checkbox = CheckboxGroup(
                labels=['Show only major lines'], active=[])

        self.speclines_callback = CustomJS(
            args = dict(line_data = viewer_cds.cds_spectral_lines,
                        lines = plots.speclines,
                        line_labels = plots.specline_labels,
                        zlines = plots.zoom_speclines,
                        zline_labels = plots.zoom_specline_labels,
                        lines_button_group = self.speclines_button_group,
                        majorline_checkbox = self.majorline_checkbox),
            code="""
            var show_emission = false
            var show_absorption = false
            if (lines_button_group.active.indexOf(0) >= 0) {  // index 0=Emission in active list
                show_emission = true
            }
            if (lines_button_group.active.indexOf(1) >= 0) {  // index 1=Absorption in active list
                show_absorption = true
            }

            for(var i=0; i<lines.length; i++) {
                if ( !(line_data.data['major'][i]) && (majorline_checkbox.active.indexOf(0)>=0) ) {
                    lines[i].visible = false
                    line_labels[i].visible = false
                    zlines[i].visible = false
                    zline_labels[i].visible = false
                } else if (line_data.data['emission'][i]) {
                    lines[i].visible = show_emission
                    line_labels[i].visible = show_emission
                    zlines[i].visible = show_emission
                    zline_labels[i].visible = show_emission
                } else {
                    lines[i].visible = show_absorption
                    line_labels[i].visible = show_absorption
                    zlines[i].visible = show_absorption
                    zline_labels[i].visible = show_absorption
                }
            }
            """
        )
        self.speclines_button_group.js_on_click(self.speclines_callback)
        self.majorline_checkbox.js_on_click(self.speclines_callback)


    def add_model_select(self, viewer_cds, template_dicts, num_approx_fits, with_full_2ndfit=True):
        """Add a selector for the secondary ("other") model drawn as a dashed curve.

        Options: 'Best fit', '2nd best fit' (omitted if ``with_full_2ndfit`` is
        False), ``num_approx_fits`` approximate fits, and 'STD QSO/GALAXY/STAR'.
        ``template_dicts`` is indexed as [0] fit templates, [1] fit results,
        [2] standard templates.
        """
        #------
        #- Select secondary model to display
        model_options = ['Best fit', '2nd best fit']
        for i in range(1,1+num_approx_fits) :
            ith = 'th'
            if i==1 : ith='st'
            if i==2 : ith='nd'
            if i==3 : ith='rd'
            model_options.append(str(i)+ith+' fit (approx)')
        if with_full_2ndfit is False :
            model_options.remove('2nd best fit')
        for std_template in ['QSO', 'GALAXY', 'STAR'] :
            model_options.append('STD '+std_template)
        self.model_select = Select(value=model_options[0], title="Other model (dashed curve):", options=model_options)
        model_select_code = self.js_files["interp_grid.js"] + self.js_files["smooth_data.js"] + self.js_files["select_model.js"]
        self.model_select_callback = CustomJS(
            args = dict(ifiberslider = self.ifiberslider,
                        model_select = self.model_select,
                        fit_templates=template_dicts[0],
                        cds_othermodel = viewer_cds.cds_othermodel,
                        cds_model_2ndfit = viewer_cds.cds_model_2ndfit,
                        cds_model = viewer_cds.cds_model,
                        fit_results=template_dicts[1],
                        std_templates=template_dicts[2],
                        median_spectra = viewer_cds.cds_median_spectra,
                        smootherslider = self.smootherslider,
                        z_input = self.z_input,
                        cds_metadata = viewer_cds.cds_metadata),
                        code = model_select_code)
        self.model_select.js_on_change('value', self.model_select_callback)


    def add_update_plot_callback(self, viewer_cds, plots, vi_widgets, template_dicts):
        """Attach the main plot-update JS callback to the spectrum and smoothing sliders.

        Must be called after the widgets it references exist (metadata tables,
        ``z_input``, ``model_select``...). ``template_dicts`` may be None; otherwise
        its element [1] is passed as ``fit_results``.
        """
        #-----
        #- Main js code to update plots
        update_plot_code = (self.js_files["adapt_plotrange.js"] + self.js_files["interp_grid.js"] +
                            self.js_files["smooth_data.js"] + self.js_files["coadd_brz_cameras.js"] +
                            self.js_files["update_plot.js"])
        # TMP handling of template_dicts
        the_fit_results = None if template_dicts is None else template_dicts[1] # dirty
        self.update_plot_callback = CustomJS(
            args = dict(
                spectra = viewer_cds.cds_spectra,
                coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                model = viewer_cds.cds_model,
                othermodel = viewer_cds.cds_othermodel,
                model_2ndfit = viewer_cds.cds_model_2ndfit,
                metadata = viewer_cds.cds_metadata,
                fit_results = the_fit_results,
                shortcds_table_z = self.shortcds_table_z,
                shortcds_table_a = self.shortcds_table_a,
                shortcds_table_b = self.shortcds_table_b,
                shortcds_table_c = self.shortcds_table_c,
                shortcds_table_d = self.shortcds_table_d,
                ifiberslider = self.ifiberslider,
                smootherslider = self.smootherslider,
                z_input = self.z_input,
                fig = plots.fig,
                imfig_source = plots.imfig_source,
                imfig_urls = plots.imfig_urls,
                model_select = self.model_select,
                vi_comment_input = vi_widgets.vi_comment_input,
                vi_std_comment_select = vi_widgets.vi_std_comment_select,
                vi_name_input = vi_widgets.vi_name_input,
                vi_quality_input = vi_widgets.vi_quality_input,
                vi_quality_labels = vi_widgets.vi_quality_labels,
                vi_issue_input = vi_widgets.vi_issue_input,
                vi_z_input = vi_widgets.vi_z_input,
                vi_category_select = vi_widgets.vi_category_select,
                vi_issue_slabels = vi_widgets.vi_issue_slabels
                ),
            code = update_plot_code
        )
        self.smootherslider.js_on_change('value', self.update_plot_callback)
        self.ifiberslider.js_on_change('value', self.update_plot_callback)
Exemple #16
0
def create_final_box_plot(method, region):
    """Build the interactive eta-vs-response box plot for masks 3-6.

    For each mask, loads
    ``numpy_files/arrays_reg<region>_<mask><samples>_<method>.npz`` and reads
    its ``response_eta`` array into a two-column (response, eta) frame.

    Two kinds of figure are built:
    * one eta-vs-response scatter per mask, with two vertical lines marking
      the selected eta window (built here but not returned);
    * a combined box plot with one box per mask.

    The boxes, whiskers and outliers are recomputed in the browser by a
    CustomJS callback whenever the slider moves. The slider sets the left
    edge of a 0.03-wide eta window.

    Relies on module-level names defined elsewhere in this file: ``samples``,
    ``colors``, ``colors_diamonds``, ``start``, ``end``, ``x_range`` and
    ``y_range``.

    Args:
        method: method key used in the file name and the title; one of
            ``'nocorr'``, ``'corr_ed'``, ``'corr_fineeta'``.
        region: integration-region key (a string), ``'1'``, ``'2'`` or ``'3'``.

    Returns:
        tuple: ``(slider, p2)``, the eta slider that drives the callback and
        the box-plot figure.
    """
    masks = [3, 4, 5, 6]
    data_limit = 64800
    delta_eta = 0.03  # width of the eta window selected by the slider
    # Initial values of the per-mask box quantities (replaced on first slider move).
    box_init = (('y1', 0.), ('y2', 0.), ('height1', .6), ('height2', .6),
                ('q1', -.25), ('q3', -.25), ('upper', -.6), ('lower', .6))

    df = []
    for imask in masks:
        fname = ('numpy_files/arrays_reg' + region + '_' + str(imask)
                 + samples + '_' + method + '.npz')
        with np.load(fname) as f:
            frame = pd.DataFrame(f['response_eta']).transpose()
        frame.columns = ['response', 'eta']
        df.append(frame)

    def scatter_columns():
        # Return a fresh dict per call so the two sources never share column objects.
        cols = {}
        for k, imask in enumerate(masks):
            m = str(imask)
            eta = df[k].eta[:data_limit]
            cols['x_mask' + m] = eta
            cols['y_mask' + m] = df[k].response[:data_limit]
            # Size the colour column from the data actually present, so every
            # column has the same length even when a file has < data_limit rows.
            cols['c_mask' + m] = [colors[k]] * len(eta)
        return cols

    # data sources
    source1 = ColumnDataSource(data=scatter_columns())
    source1_v2 = ColumnDataSource(data=scatter_columns())
    source_vlines = ColumnDataSource(data=dict(etamin=[start], etamax=[end]))

    outlier_cols = {}
    box_cols = {}
    for k, imask in enumerate(masks):
        m = str(imask)
        xpos = k + 1  # x position of this mask's box in the second figure
        response = df[k].response[:data_limit]
        # Initial outliers are shifted up by 100, presumably to keep them out of
        # the visible y range until the first slider update.
        outlier_cols['x_mask' + m] = [xpos] * len(response)
        outlier_cols['y_mask' + m] = response + 100.
        box_cols['x_mask' + m] = [xpos]
        for key, init in box_init:
            box_cols[key + '_mask' + m] = [init]
    source_outliers = ColumnDataSource(data=outlier_cols)
    source2 = ColumnDataSource(data=box_cols)

    # 1st figure: one scatter per mask with the selected eta window marked
    p1 = []
    plot_options = dict(plot_height=500, plot_width=350, y_range=y_range,
                        tools="wheel_zoom,box_zoom,box_select,pan,reset", output_backend="webgl")
    for imask in masks:
        m = str(imask)
        fig = figure(title="Eta vs. Response          Mask " + m, **plot_options)
        fig.circle('x_mask' + m, 'y_mask' + m, color='c_mask' + m, size=2, alpha=0.4, source=source1)
        for edge in ('etamin', 'etamax'):
            fig.segment(x0=edge, y0=y_range[0], x1=edge, y1=y_range[1], source=source_vlines,
                        line_color='black', line_width=2)
        p1.append(fig)

    # 2nd figure: box plot, one box per mask
    plot2_options = dict(plot_height=300, plot_width=500, x_range=x_range,
                         y_range=(y_range[0] - 0.3, y_range[1] + 0.3),
                         tools="wheel_zoom,box_zoom,pan,reset", output_backend="webgl")
    radius = {'1': '1.3cm', '2': '2.6cm', '3': '5.3cm'}
    methoddict = {'nocorr': 'No correction', 'corr_ed': 'Shower leakage', 'corr_fineeta': 'Brute force'}
    p2 = figure(title="Signal integration radius: " + radius[region] + '        Method: ' + methoddict[method],
                **plot2_options)
    p2.xaxis.visible = False
    box_options = dict(line_color='black', line_width=2, source=source2)
    for k, imask in enumerate(masks):
        m = str(imask)
        # boxes: lower half spans q1..median (y1/height1), upper half median..q3 (y2/height2)
        p2.rect(x='x_mask' + m, y='y1_mask' + m, width=0.9, height='height1_mask' + m,
                color=colors[k], legend='Mask ' + m, **box_options)
        p2.rect(x='x_mask' + m, y='y2_mask' + m, width=0.9, height='height2_mask' + m,
                color=colors[k], **box_options)
        # segments from the box edges to the whisker ends
        p2.segment(x0='x_mask' + m, y0='q3_mask' + m, x1='x_mask' + m, y1='upper_mask' + m, **box_options)
        p2.segment(x0='x_mask' + m, y0='q1_mask' + m, x1='x_mask' + m, y1='lower_mask' + m, **box_options)
        # whisker caps
        p2.rect(x='x_mask' + m, y='upper_mask' + m, width=0.1, height=0.005, **box_options)
        p2.rect(x='x_mask' + m, y='lower_mask' + m, width=0.1, height=0.005, **box_options)
        # outliers
        p2.diamond(x='x_mask' + m, y='y_mask' + m, line_color=colors_diamonds[k], size=4,
                   source=source_outliers)
    # dashed line at zero response; independent of the mask, so drawn once
    p2.segment(x0=x_range[0] + 0.5, y0=0., x1=x_range[1] - 0.5, y1=0., line_color='black', line_dash='dashed')

    # Dynamic behaviour (the model that triggers the callback is called 'cb_obj').
    # The mask list and window width are injected as JS literals so the Python
    # and JS sides cannot drift apart.
    callback_js = """
            var etamin = cb_obj.value;
            var etamax = etamin + delta_eta;
            var data1 = source1.data;
            var data1_v2 = source1_v2.data;
            var data2 = source2.data;
            var data_lines = source_vlines.data;
            var data_out = source_outliers.data;
            data_lines['etamin'][0] = etamin;
            data_lines['etamax'][0] = etamax;

            // Quantile of an ascending-sorted, non-empty array: the element at
            // floor(n*frac) for odd n, otherwise the mean of it and its predecessor.
            function quantile(sorted, frac) {
                var idx = Math.floor(sorted.length * frac);
                if (sorted.length % 2) {
                    return sorted[idx];
                }
                // Clamp so that n == 2 does not read sorted[-1].
                return (sorted[Math.max(idx - 1, 0)] + sorted[idx]) / 2.0;
            }

            var box_keys = ['y1', 'y2', 'height1', 'height2', 'q1', 'q3', 'upper', 'lower'];

            for (var k = 0; k < masks.length; k++) {
                var m = masks[k];
                var xs = data1['x_mask' + m];
                var ys = data1['y_mask' + m];

                // select the points with etamin < eta <= etamin + delta_eta
                var sel_x = [];
                var sel_y = [];
                for (var i = 0; i < xs.length; i++) {
                    if (xs[i] > etamin && xs[i] <= etamax) {
                        sel_x.push(xs[i]);
                        sel_y.push(ys[i]);
                    }
                }
                data1_v2['x_mask' + m] = sel_x;
                data1_v2['y_mask' + m] = sel_y;

                // sort a copy so the x/y pairing in source1_v2 stays intact
                var sorted = sel_y.slice().sort(function(a, b) { return a - b; });
                var n = sorted.length;
                var outliers = [];

                if (n === 0) {
                    // empty window: blank this mask's box rather than drawing NaN/Infinity
                    for (var j = 0; j < box_keys.length; j++) {
                        data2[box_keys[j] + '_mask' + m][0] = NaN;
                    }
                } else {
                    var q1 = quantile(sorted, 0.25);
                    var q2 = quantile(sorted, 0.50);
                    var q3 = quantile(sorted, 0.75);
                    var upper = q3 + 1.5 * (q3 - q1);
                    var lower = q1 - 1.5 * (q3 - q1);
                    data2['y1_mask' + m][0] = (q2 + q1) / 2;
                    data2['y2_mask' + m][0] = (q3 + q2) / 2;
                    data2['height1_mask' + m][0] = q2 - q1;
                    data2['height2_mask' + m][0] = q3 - q2;
                    data2['q1_mask' + m][0] = q1;
                    data2['q3_mask' + m][0] = q3;
                    // whiskers reach the 1.5*IQR fences but never beyond the data range
                    data2['upper_mask' + m][0] = Math.min(sorted[n - 1], upper);
                    data2['lower_mask' + m][0] = Math.max(sorted[0], lower);
                    for (var j = 0; j < n; j++) {
                        if (sorted[j] > upper || sorted[j] < lower) {
                            outliers.push(sorted[j]);
                        }
                    }
                }

                // keep the outlier source's x and y columns the same length
                var xpos = data2['x_mask' + m][0];
                data_out['y_mask' + m] = outliers;
                data_out['x_mask' + m] = outliers.map(function() { return xpos; });
            }

            //update sources
            source1.change.emit();
            source1_v2.change.emit();
            source2.change.emit();
            source_vlines.change.emit();
            source_outliers.change.emit();
        """
    callback_code = ("var masks = " + str(masks) + ";\n"
                     + "var delta_eta = " + repr(delta_eta) + ";\n"
                     + callback_js)
    callback = CustomJS(args=dict(source1=source1, source1_v2=source1_v2, source2=source2,
                                  source_vlines=source_vlines, source_outliers=source_outliers),
                        code=callback_code)
    slider = Slider(start=start, end=end, value=start, step=.0001, title="Pseudorapidity (left bin edge)")
    slider.js_on_change('value', callback)
    return slider, p2
Exemple #17
0
def plot(numFrac, numPF, args):
    """Compute NSA_ss maps and build the interactive 3 T dashboard.

    The parameter grid is:
    * acquisition times from ``np.arange(args.tmin, args.tmax, args.dt)``,
      in ms (they are divided by 1e3 before use);
    * ``numPF`` partial-Fourier factors on [0.5, 1];
    * ``numFrac`` first-echo fractions on [0, 1].

    For each grid point it stores the two dephasing times from
    ``getDephasingTimes``. It also stores the reciprocal of element [2] of
    ``weightedCrbTwoEchoes``, once with weights from the first-echo fraction
    ("weighted") and once with ``weightsFromFraction(.5)`` ("unweighted").

    The dashboard contains:
    * the two NSA maps;
    * a gradient diagram and a fat-vector compass, both updated on hover;
    * an acquisition-time slider that swaps the displayed maps.

    Args:
        numFrac: number of first-echo fractions sampled.
        numPF: number of partial-Fourier factors sampled.
        args: namespace providing ``tmin``, ``tmax``, ``dt``, ``svg``,
            ``filename`` and ``title``.

    If ``args.svg is True``, only the weighted map is exported to
    ``pweighted.svg``. Otherwise the full layout is saved as HTML to
    ``args.filename``.
    """
    B0 = 3  # Tesla
    mapper = LinearColorMapper(palette='Spectral10', low=0, high=1)
    colorBar = ColorBar(color_mapper=mapper, location=(0, 0))
    acquistionTimes = np.arange(start=float(args.tmin),
                                stop=float(args.tmax),
                                step=float(args.dt))
    numTa = len(acquistionTimes)
    partialFourierFactors = np.linspace(start=0.5, stop=1, num=numPF)
    dephasingTimes = np.empty(shape=(numPF, numFrac, numTa, 2),
                              dtype=np.float32)
    firstEchoFractions = np.linspace(start=0, stop=1, num=numFrac)
    NSA = np.empty(shape=(numPF, numFrac, numTa, 2),
                   dtype=np.float32)  # NSA_ss [weighted, unweighted]
    # Weights of the "unweighted" reference do not depend on the loop variables.
    unweightedWeights = weightsFromFraction(.5)

    for nt, ta in enumerate(acquistionTimes):
        for nPF, PF in enumerate(partialFourierFactors):
            for nf, f in enumerate(firstEchoFractions):
                # ta is in ms; divided by 1e3 (presumably to seconds) for the helper.
                dephasingTimes[nPF, nf, nt, :] = getDephasingTimes(ta / 1.0e3, PF, f)
                times = dephasingTimes[nPF, nf, nt, :]
                NSA[nPF, nf, nt, 0] = np.reciprocal(
                    weightedCrbTwoEchoes(B0, times,
                                         weightsFromFraction(f))[2])  # Weighted NSA_ss
                NSA[nPF, nf, nt, 1] = np.reciprocal(
                    weightedCrbTwoEchoes(B0, times,
                                         unweightedWeights)[2])  # Unweighted NSA_ss
        print(100. * (nt + 1) / numTa)  # progress in percent

    pWeighted = figure(height=350,
                       width=350,
                       toolbar_location=None,
                       title='Weighted NSA_ss (3T)')
    pUnWeighted = figure(height=350,
                         width=350,
                         toolbar_location=None,
                         title='Unweighted NSA_ss (3T)')
    pUnWeighted.add_layout(colorBar, 'right')
    # Start on the last (longest) acquisition time, matching the slider's initial value.
    CDSimages = [
        ColumnDataSource({'imageData': [NSA[:, :, -1, 0]]}),
        ColumnDataSource({'imageData': [NSA[:, :, -1, 1]]})
    ]
    pWeighted.image(image='imageData',
                    x=0,
                    y=0,
                    dw=numFrac,
                    dh=numPF,
                    color_mapper=mapper,
                    source=CDSimages[0])
    pUnWeighted.image(image='imageData',
                      x=0,
                      y=0,
                      dw=numFrac,
                      dh=numPF,
                      color_mapper=mapper,
                      source=CDSimages[1])

    for p in [pWeighted, pUnWeighted]:
        p.xaxis.ticker = [
            0, (numFrac - 1) / 4, (numFrac - 1) / 2, 3 * (numFrac - 1) / 4,
            numFrac - 1
        ]
        p.xaxis.major_label_overrides = {
            0: '0',
            (numFrac - 1) / 4: '.25',
            (numFrac - 1) / 2: '.5',
            3 * (numFrac - 1) / 4: '.75',
            numFrac - 1: '1'
        }
        p.yaxis.ticker = [0, (numPF - 1) / 2, numPF - 1]
        p.yaxis.major_label_overrides = {
            0: '0.5',
            (numPF - 1) / 2: '.75',
            numPF - 1: '1.0'
        }

    pWeighted.x_range.range_padding = pWeighted.y_range.range_padding = 0
    pUnWeighted.x_range.range_padding = pUnWeighted.y_range.range_padding = 0
    # Echo markers on the gradient plot; placed at -1 until the first hover event.
    spans = [
        Span(location=-1,
             dimension='height',
             line_color='navy',
             line_dash='dashed'),
        Span(location=-1,
             dimension='height',
             line_color='chocolate',
             line_dash='dashed')
    ]
    pGrad = figure(height=350,
                   width=350,
                   toolbar_location=None,
                   title='Gradients')
    pGrad.add_layout(spans[0])
    pGrad.add_layout(spans[1])
    CDSFirst = ColumnDataSource({'t': [0, 0, .5, .5], 'amp': [0, 1, 1, 0]})
    CDSSecond = ColumnDataSource({'t': [.5, .5, 1, 1], 'amp': [0, -1, -1, 0]})
    pGrad.line(x='t', y='amp', color='navy', line_width=2, source=CDSFirst)
    pGrad.line(x='t',
               y='amp',
               color='chocolate',
               line_width=2,
               source=CDSSecond)

    pCompass = figure(height=350,
                      width=350,
                      toolbar_location=None,
                      title='Fat vectors',
                      x_range=(-1.1, 1.1),
                      y_range=(-1.1, 1.1))
    pCompass.circle(0, 0, radius=1, fill_color=None, line_color='black')
    CDSArrow = [
        ColumnDataSource({
            'x': [0, 0],
            'y': [0, 0]
        }),
        ColumnDataSource({
            'x': [0, 0],
            'y': [0, 0]
        })
    ]
    pCompass.line(x='x', y='y', color='navy', line_width=2, source=CDSArrow[0])
    pCompass.line(x='x',
                  y='y',
                  color='chocolate',
                  line_width=2,
                  source=CDSArrow[1])

    # Use the configured grid step directly: indexing acquistionTimes[1] would
    # fail when the grid holds a single acquisition time.
    slider = Slider(start=np.min(acquistionTimes),
                    end=np.max(acquistionTimes),
                    value=np.max(acquistionTimes),
                    step=float(args.dt),
                    title="Available acquisition time [ms]")

    hoverCallback = CustomJS(args={
        'dephasingTimes': dephasingTimes,
        'firstEchoFractions': firstEchoFractions,
        'partialFourierFactors': partialFourierFactors,
        'spans': spans,
        'arrows': CDSArrow,
        'first': CDSFirst,
        'slider': slider,
        'second': CDSSecond
    },
                             code="""
                            if ( isFinite(cb_data['geometry']['x']) && isFinite(cb_data['geometry']['y']) ) {
                                let ta = slider.value / 1000.0;
                                if (typeof window.taIdx == 'undefined') {
                                    window.taIdx = dephasingTimes[0][0].length - 1;
                                }
                                // Clamp to valid indices: the pointer can sit exactly on
                                // (or just outside) the image edge.
                                let fIdx = Math.min(Math.max(Math.floor(cb_data["geometry"]['x']), 0), firstEchoFractions.length - 1);
                                let pfIdx = Math.min(Math.max(Math.floor(cb_data["geometry"]['y']), 0), partialFourierFactors.length - 1);
                                first.data.t[2] = first.data.t[3] = ta*firstEchoFractions[fIdx];
                                second.data.t[0] = second.data.t[1] = ta*firstEchoFractions[fIdx];
                                second.data.t[2] = second.data.t[3] = ta;
                                first.data.amp[1] = first.data.amp[2] = partialFourierFactors[pfIdx] / firstEchoFractions[fIdx];
                                second.data.amp[1] = second.data.amp[2] = - partialFourierFactors[pfIdx] / (1 - firstEchoFractions[fIdx]);
                                spans[0].location = ta/2.0 + dephasingTimes[pfIdx][fIdx][window.taIdx][0];
                                spans[1].location = ta/2.0 + dephasingTimes[pfIdx][fIdx][window.taIdx][1];

                                spans[0].change.emit();
                                spans[1].change.emit();
                                first.change.emit();
                                second.change.emit();
                                let omega = 2*Math.PI*42.58*3*3.4;

                                for (let i = 0; i < 2; i++) {
                                    arrows[i].data.x[1] = Math.cos(omega*dephasingTimes[pfIdx][fIdx][window.taIdx][i]);
                                    arrows[i].data.y[1] = Math.sin(omega*dephasingTimes[pfIdx][fIdx][window.taIdx][i]);
                                    arrows[i].change.emit();
                                }

                                window.fIdx = fIdx;
                                window.pfIdx = pfIdx;
                            }
                            """)
    pWeighted.add_tools(
        HoverTool(tooltips=[('index', '$index'), ('x', '$x'), ('y', '$y')],
                  callback=hoverCallback,
                  mode='mouse'))
    pUnWeighted.add_tools(
        HoverTool(tooltips=[('index', '$index'), ('x', '$x'), ('y', '$y')],
                  callback=hoverCallback,
                  mode='mouse'))
    sliderCallback = CustomJS(args={
        'acquistionTimes': acquistionTimes,
        'images': CDSimages,
        'NSA': NSA
    },
                              code="""
        // Round (not floor): (value - start) / step is floating point and can
        // fall just below the intended integer. Clamp to the grid.
        window.taIdx = Math.min(Math.max(Math.round((cb_obj.value - cb_obj.start) / cb_obj.step), 0),
                                acquistionTimes.length - 1);
        // Keep each image 2-D ([numPF][numFrac]), the same shape as the initial data.
        images[0].data.imageData = [NSA.map(PF => PF.map(f => f[window.taIdx][0]))];
        images[1].data.imageData = [NSA.map(PF => PF.map(f => f[window.taIdx][1]))];
        images[0].change.emit();
        images[1].change.emit();
        """)
    slider.js_on_change('value', sliderCallback)
    if args.svg is True:
        from bokeh.io import export_svgs
        fname = 'pweighted.svg'
        print('Saving ' + fname)
        pWeighted.output_backend = "svg"
        export_svgs(pWeighted, filename=fname)
        print('Done')
    else:
        output_file(args.filename)
        combinedPlot = column(row(pWeighted, pUnWeighted),
                              row(pGrad, pCompass), slider)
        save(combinedPlot, filename=args.filename, title=args.title)
#---------------------------------------------------------------#
# Set Up Callbacks
# Month slider: copy the selected image from s2 into s1 and set the paragraph
# text to "<month> <year>". The slider value indexes months consecutively:
# value % 12 picks the month, floor(value / 12) picks the year.
callback = CustomJS(args=dict(s1=s1, s2=s2, s3=s3, s4=s4, paragraph=paragraph),
                    code="""
        var inds = cb_obj.value;
        var d1 = s1.data;
        var d2 = s2.data;
        var index = d2['image'];
        d1['image'] = index[inds];

        var months = s3.data['months'];
        var years = s4.data['years'];
        var month = months[inds % 12];
        var year = years[Math.floor(inds / 12)];
        // declared locally so the callback does not create a global `title`
        var title = `${month} ${year}`;
        paragraph.text = title;

        s1.change.emit();
    """)
year.js_on_change('value', callback)

# Opacity slider: CustomJS code only sees the names listed in `args`, so the
# renderer must be addressed as `r` (an undeclared name raises a ReferenceError).
callback_opacity = CustomJS(args=dict(r=r),
                            code="""
        r.glyph.global_alpha = cb_obj.value;
    """)
opacity.js_on_change('value', callback_opacity)
#---------------------------------------------------------------#
# Create the Layout
layout = row(column(paragraph, widgetbox(year), opacity), p, width=800)

show(layout)
Exemple #19
0
def _style_waveform_figure(fig, style_parameter, xlabel, ylabel):
    """Apply the title/axis/toolbar styling shared by every figure of the waveform page.

    Parameters
    ----------
    fig : bokeh Figure
        Figure styled in place.
    style_parameter : dict
        Reads 'title_font_size' and 'xlabel_fontsize' (the latter is used for
        both axes' label and tick fonts).
    xlabel, ylabel : str
        Axis label texts.
    """
    fontsize = style_parameter['xlabel_fontsize']
    fig.title.text_font_size = style_parameter['title_font_size']
    fig.title.align = 'center'
    fig.title.text_font_style = 'normal'
    for axis, label in ((fig.xaxis, xlabel), (fig.yaxis, ylabel)):
        axis.axis_label = label
        axis.axis_label_text_font_style = 'normal'
        axis.axis_label_text_font_size = fontsize
        axis.major_label_text_font_size = fontsize
    fig.toolbar.logo = None
    fig.toolbar_location = 'above'
    fig.toolbar_sticky = False


def _pad_trace(trace, length):
    """Return *trace* as a list, zero-padded at the end up to *length* samples."""
    missing = length - len(trace)
    if missing > 0:
        trace = np.append(trace, np.zeros(missing))
    return list(trace)


# JavaScript shared by the map-click and slider callbacks. The caller prepends
# one line defining `inds`, the index of the station to display.
_WAVEFORM_UPDATE_JS = """
    var new_loc = map_station_location_bokeh.data;
    // Ignore empty selections (e.g. a map deselect) and out-of-range indices.
    if (inds === undefined || inds < 0 || inds >= new_loc['map_lat_list'].length) {
        return;
    }

    selected_dot_on_map_bokeh.data['index'] = [inds];
    selected_dot_on_map_bokeh.data['lat'] = [new_loc['map_lat_list'][inds]];
    selected_dot_on_map_bokeh.data['lon'] = [new_loc['map_lon_list'][inds]];
    selected_dot_on_map_bokeh.change.emit();

    // The waveform line glyphs read the 'time' and 'amp' columns.
    var curve_sources = [selected_curve_data_bokeh01,
                         selected_curve_data_bokeh02,
                         selected_curve_data_bokeh03];
    var channel_sources = [selected_channel_label_bokeh01,
                           selected_channel_label_bokeh02,
                           selected_channel_label_bokeh03];
    for (var i = 0; i < curve_sources.length; i++) {
        curve_sources[i].data['time'] = all_curve_data_bokeh.data['t'][inds][i];
        curve_sources[i].data['amp'] = all_curve_data_bokeh.data['amp'][inds][i];
        curve_sources[i].change.emit();
        channel_sources[i].data['label'] = [all_channel_label_bokeh.data['label'][inds][i]];
        channel_sources[i].change.emit();
    }

    selected_reftime_label_bokeh01.data['label'] = [all_reftime_label_bokeh.data['label'][inds][0]];
    selected_reftime_label_bokeh01.change.emit();
"""


def plot_waveform_bokeh(filename, waveform_list, metadata_list, station_lat_list,
                        station_lon_list, event_lat, event_lon, boundary_data, style_parameter):
    """Write an interactive HTML page: station map plus the first three traces of a station.

    Clicking a station on the map, or moving the slider, swaps the three
    waveform panels (and their channel / start-time labels) to that station.

    Parameters
    ----------
    filename : str
        Output HTML path passed to ``output_file``.
    waveform_list : list
        Per station, a sequence of traces (1-D amplitude arrays); the first
        three traces of each station are plotted.
    metadata_list : list
        Per station, per trace, a dict with 'delta', 'starttime', 'network',
        'station' and 'channel' ('starttime' is string-concatenated).
    station_lat_list, station_lon_list : list
        Station coordinates, in the same order as ``waveform_list``.
    event_lat, event_lon : float
        Event location, drawn as a red asterisk.
    boundary_data : dict
        'country', 'marine', 'shoreline' and 'state' entries, each holding
        'longitude'/'latitude' multi-line data.
    style_parameter : dict
        Sizes, labels, titles and tool settings for the page.
    """
    n_components = 3  # panels per station; the JS above is written for exactly three
    curve_tools = ['save', 'box_zoom', 'ywheel_zoom', 'xwheel_zoom', 'reset', 'crosshair', 'pan']
    # ------------------------------ map view
    map_station_location_bokeh = ColumnDataSource(data=dict(map_lat_list=station_lat_list,
                                                            map_lon_list=station_lon_list))
    # NOTE(review): the slider starts at style_parameter['curve_default_index'], while the
    # initially displayed station is index 0 — confirm these are meant to agree.
    dot_default_index = 0
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(lat=[station_lat_list[dot_default_index]],
                                                           lon=[station_lon_list[dot_default_index]],
                                                           index=[dot_default_index]))
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'],
                      plot_height=style_parameter['map_view_plot_height'],
                      y_range=[style_parameter['map_view_lat_min'], style_parameter['map_view_lat_max']],
                      x_range=[style_parameter['map_view_lon_min'], style_parameter['map_view_lon_max']],
                      tools=style_parameter['map_view_tools'],
                      title=style_parameter['map_view_title'])
    # Boundaries are drawn underneath everything; country and shoreline are thicker.
    for boundary, extra in (('country', {'line_width': 2}), ('marine', {}),
                            ('shoreline', {'line_width': 2}), ('state', {})):
        map_view.multi_line(boundary_data[boundary]['longitude'],
                            boundary_data[boundary]['latitude'], color='gray',
                            level='underlay', nonselection_line_alpha=1.0,
                            nonselection_line_color='gray', **extra)
    # All stations; selection styling is neutralised so only the red marker below stands out.
    map_view.triangle('map_lon_list', 'map_lat_list', source=map_station_location_bokeh,
                      line_color='gray', size=style_parameter['marker_size'], fill_color='black',
                      selection_color='black', selection_line_color='gray',
                      selection_fill_alpha=1.0,
                      nonselection_fill_alpha=1.0, nonselection_fill_color='black',
                      nonselection_line_color='gray', nonselection_line_alpha=1.0)
    # Currently displayed station.
    map_view.triangle('lon', 'lat', source=selected_dot_on_map_bokeh,
                      size=style_parameter['selected_marker_size'], line_color='black', fill_color='red')
    map_view.asterisk([event_lon], [event_lat], size=style_parameter['event_marker_size'],
                      line_width=3, line_color='red', fill_color='red')
    _style_waveform_figure(map_view, style_parameter,
                           style_parameter['map_view_xlabel'], style_parameter['map_view_ylabel'])
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    # ------------------------------ waveform data
    ncurve = len(waveform_list)
    all_traces = [a_trace for a_sta in waveform_list for a_trace in a_sta]
    max_waveform_length = max((len(a_trace) for a_trace in all_traces), default=0)
    max_waveform_amp = max((np.max(np.abs(a_trace)) for a_trace in all_traces if len(a_trace) > 0),
                           default=0)
    # Pad every trace to a common length so all panels share one array size.
    plotting_list = [[_pad_trace(a_trace, max_waveform_length) for a_trace in a_sta]
                     for a_sta in waveform_list]
    time_list = [[list(np.arange(len(a_trace)) * metadata_list[ista][itr]['delta'])
                  for itr, a_trace in enumerate(a_sta)]
                 for ista, a_sta in enumerate(plotting_list)]
    # The x data are times (samples * delta), so the x extent must be a time as well.
    max_time = float(max((t[-1] for a_sta in time_list for t in a_sta if len(t) > 0), default=0))
    reftime_label_list = [['Starting from ' + a_trace['starttime'] for a_trace in a_sta]
                          for a_sta in metadata_list]
    channel_label_list = [[a_trace['network'] + '_' + a_trace['station'] + '_' + a_trace['channel']
                           for a_trace in a_sta]
                          for a_sta in metadata_list]
    # Full data, read by the JS callbacks when the station changes.
    all_curve_data_bokeh = ColumnDataSource(data=dict(t=time_list, amp=plotting_list))
    all_reftime_label_bokeh = ColumnDataSource(data=dict(label=reftime_label_list))
    all_channel_label_bokeh = ColumnDataSource(data=dict(label=channel_label_list))
    selected_reftime_label_bokeh01 = ColumnDataSource(data=dict(x=[style_parameter['curve_reftime_label_x']],
                                                                y=[style_parameter['curve_reftime_label_y']],
                                                                label=[reftime_label_list[dot_default_index][0]]))
    # ------------------------------ one panel per component
    curve_figs, curve_sources, channel_sources = [], [], []
    for curve_index in range(n_components):
        fig = Figure(plot_width=style_parameter['curve_plot_width'],
                     plot_height=style_parameter['curve_plot_height'],
                     y_range=(-max_waveform_amp * 1.05, max_waveform_amp * 1.05),
                     x_range=(0, max_time),
                     tools=list(curve_tools))
        curve_source = ColumnDataSource(data=dict(time=time_list[dot_default_index][curve_index],
                                                  amp=plotting_list[dot_default_index][curve_index]))
        channel_source = ColumnDataSource(data=dict(x=[style_parameter['curve_channel_label_x']],
                                                    y=[style_parameter['curve_channel_label_y']],
                                                    label=[channel_label_list[dot_default_index][curve_index]]))
        fig.line('time', 'amp', source=curve_source, line_color='black')
        if curve_index == 0:
            # Only the top panel carries the reference (start) time label.
            fig.text('x', 'y', 'label', source=selected_reftime_label_bokeh01)
        fig.text('x', 'y', 'label', source=channel_source)
        _style_waveform_figure(fig, style_parameter,
                               style_parameter['curve_xlabel'], style_parameter['curve_ylabel'])
        curve_figs.append(fig)
        curve_sources.append(curve_source)
        channel_sources.append(channel_source)
    # ------------------------------ interaction
    js_args = dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,
                   map_station_location_bokeh=map_station_location_bokeh,
                   selected_reftime_label_bokeh01=selected_reftime_label_bokeh01,
                   all_reftime_label_bokeh=all_reftime_label_bokeh,
                   all_channel_label_bokeh=all_channel_label_bokeh,
                   all_curve_data_bokeh=all_curve_data_bokeh)
    for i in range(n_components):
        js_args[f'selected_curve_data_bokeh{i + 1:02d}'] = curve_sources[i]
        js_args[f'selected_channel_label_bokeh{i + 1:02d}'] = channel_sources[i]
    # `indices` is an array; the first selected station drives the display.
    map_station_location_js = CustomJS(args=dict(js_args),
                                       code="\n    var inds = cb_obj.indices[0];" + _WAVEFORM_UPDATE_JS)
    map_station_location_bokeh.selected.js_on_change('indices', map_station_location_js)
    curve_slider_callback = CustomJS(args=dict(js_args),
                                     code="\n    var inds = Math.round(cb_obj.value);" + _WAVEFORM_UPDATE_JS)
    curve_slider = Slider(start=0, end=ncurve-1, value=style_parameter['curve_default_index'],
                          step=1, title=style_parameter['curve_slider_title'],
                          width=style_parameter['map_view_plot_width'], height=50)
    curve_slider.js_on_change('value', curve_slider_callback)
    # ------------------------------ annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'],
                           width=style_parameter['annotation_plot_width'],
                           height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],
                           width=style_parameter['annotation_plot_width'],
                           height=style_parameter['annotation_plot_height'])
    # ------------------------------ output
    output_file(filename,
                title=style_parameter['html_title'],
                mode=style_parameter['library_source'])
    left_fig = Column(curve_slider,
                      map_view,
                      annotating_fig01,
                      width=style_parameter['left_column_width'])
    right_fig = Column(*curve_figs,
                       annotating_fig02,
                       width=style_parameter['right_column_width'])
    layout = Row(left_fig, right_fig)
    save(layout)
    var n = cb_obj.value;
    
    var data1 = S1.data;
    
    for(var i = 0; i < data1.y.length; i++)
    {
          data1['noise'][i] = 'False';
       if(Math.abs(data1.x[i] + data1.y[i] - data1.RRSec_n_minus_2[i]) <= n)
       {
          data1['noise'][i] = 'True';
       }
    }
    S1.change.emit();
""")
# Noise-threshold slider. The `callback` built above is handed a reference to
# it and fires whenever its value changes.
slider = Slider(
    title="noise%",
    start=0.0,
    end=0.5,
    value=0.0,
    step=0.1,
    height=65,
)
callback.args["slider"] = slider
slider.js_on_change('value', callback)

# ------------------------------------------ FIGURE 1 ----------------------------------------------------------


# ------------------------------------------ FIGURE 2 ----------------------------------------------------------
# Shares p1's x and y ranges, so panning/zooming one figure moves the other.
p2 = figure(
    plot_width=700,
    plot_height=700,
    x_range=p1.x_range,
    y_range=p1.y_range,
    output_backend="webgl",
)

p2.xaxis.axis_label = "RRn-1 (seconds)"
p2.xaxis.axis_label_text_font = font_name
p2.xaxis.axis_label_text_font_size = "25px"
p2.xaxis.major_label_text_font = font_name

p2.yaxis.axis_label = "RRn (seconds)"
Exemple #21
0
def create_vp_plot(data):
    """Build a vertical-profile plot (variable on x, vertical level on a flipped y axis).

    Each column of *data* is one profile. When there are two or more columns,
    a Select, a Slider and '<' / '>' buttons are added to switch which column
    the line and circle glyphs display; the three controls are kept in sync
    through CustomJS callbacks.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``columns``, a ``variable_metadata`` mapping (with
        'long_name' or 'standard_name', optionally 'unit'/'units') and a
        ``dataset_metadata`` mapping whose 'dimension' entry names the
        vertical-level column first. With several columns, the column names
        are split on 'T' for display — presumably ISO-8601 timestamps; confirm.

    Returns
    -------
    bokeh layout
        ``column(select, p, slider_wrapper)`` for multiple profiles,
        otherwise ``column(p)``.
    """
    ds = ColumnDataSource(data)
    columns = list(data.columns)
    first_column = columns[0]
    # Bokeh field reference; braces allow names with spaces/punctuation.
    # str() wraps the column name alone so non-string names also work.
    var_label = '@{' + str(first_column) + '}'
    try:
        var_tooltip_label = str(data.variable_metadata['long_name'])
    except KeyError:
        var_tooltip_label = str(data.variable_metadata['standard_name'])
    try:
        units = list({'unit', 'units'}.intersection(data.variable_metadata))[0]
        x_axis_label = " ".join(
            [var_tooltip_label, '[', data.variable_metadata[units], ']'])
    except IndexError:
        print('no units found')
        x_axis_label = var_tooltip_label
    p = figure(toolbar_location="above",
               tools="crosshair,box_zoom, pan,save, reset, wheel_zoom",
               x_axis_type="linear",
               x_axis_label=x_axis_label)
    p.sizing_mode = 'stretch_width'
    # The first 'dimension' entry is the vertical coordinate (plotted on y).
    vertical_level = data.dataset_metadata['dimension'][0]
    hover = HoverTool(tooltips=[("Depth", "@" + vertical_level),
                                (var_tooltip_label, var_label)])

    p.add_tools(hover)
    p.y_range.flipped = True  # depth increases downwards
    p.min_border_left = 80
    p.min_border_right = 80
    p.background_fill_color = "SeaShell"
    p.background_fill_alpha = 0.5
    line_renderer = p.line(first_column,
                           vertical_level,
                           source=ds,
                           line_alpha=0.6, color='RoyalBlue',
                           )
    point_renderer = p.circle(first_column,
                              vertical_level,
                              source=ds,
                              color='RoyalBlue',
                              size=3,
                              fill_alpha=0.5,
                              fill_color='white',
                              legend_label=first_column,
                              )
    p.legend.location = "top_left"
    p.legend.click_policy = "hide"
    if len(columns) < 2:
        return column(p, sizing_mode="stretch_width")

    # Date/time readout for the currently shown profile.
    par = Div(text=get_datetime_string(first_column))
    left_btn = Button(label='<', width=30)
    right_btn = Button(label='>', width=30)
    sp = Spacer(width=50)

    def edge_label(column_name, align):
        """Div showing '<date><br><time>' for a 'dateTtime' column name."""
        parts = column_name.split('T')
        return Div(text=parts[0] + '<br>' + parts[1], style={'text-align': align})

    # Labels under the slider ends: first and last profile timestamps.
    end_label = edge_label(columns[-1], 'right')
    start_label = edge_label(columns[0], 'left')

    select = Select(title="Profile-record:",
                    options=columns,
                    value=first_column)
    slider = Slider(title="Profile #",
                    value=0,
                    start=0,
                    end=len(columns) - 1,
                    step=1,
                    show_value=True,
                    tooltips=False)

    # Select drives the glyphs, the slider position and the date/time readout.
    select_handler = CustomJS(args=dict(line_renderer=line_renderer,
                                        point_renderer=point_renderer,
                                        slider=slider,
                                        par=par),
                              code="""
           line_renderer.glyph.x = {field: cb_obj.value};
           point_renderer.glyph.x = {field: cb_obj.value};
           slider.value = cb_obj.options.indexOf(cb_obj.value);
           var date_time = cb_obj.value.split("T");
           var date = date_time[0];
           var time = date_time[1];
           par.text = `<ul style="text-align:left;"><li>Date: <b>`+date+`</b></li><li>Time: <b>`+time+`</b></li></ul>`;
        """)
    select.js_on_change('value', select_handler)
    # Slider maps its integer position back onto the Select's options.
    slider_handler = CustomJS(args=dict(select=select),
                              code="""
           select.value = select.options[cb_obj.value];
        """)
    slider.js_on_change('value', slider_handler)

    # '<' / '>' step the slider by one, staying within [start, end].
    left_btn_handler = """
        if(slider.value > slider.start) {
            slider.value = slider.value - 1;
            slider.change.emit();
        }
        """
    left_btn.js_on_click(CustomJS(args=dict(slider=slider), code=left_btn_handler))

    right_btn_handler = """
        if(slider.value <= slider.end - 1) {
            slider.value = slider.value + 1;
            slider.change.emit();
        }
        """
    right_btn.js_on_click(CustomJS(args=dict(slider=slider), code=right_btn_handler))

    slider_wrapper = layout([
        [sp, sp, slider, left_btn, right_btn, par],
        [sp, start_label, sp, sp, end_label, sp, sp],
    ])
    return column(select, p, slider_wrapper, sizing_mode="stretch_width")
        Sexpsq+=exp[i]*exp[i];
        
        }
        
        TSR=(95*Sytexp-Sexp*Syt)/Math.sqrt(95*Sytsq-Syt*Syt)/Math.sqrt(95*Sexpsq-Sexp*Sexp)
        
        TSR=Math.round(TSR*100)/100
        
        corlabTS.text='TS R='.concat(TSR.toString())

        source.trigger('change');
    """)


# Wire each widget to its CustomJS handler (cb, cb_slider, cb_ff and cb_struct
# are defined elsewhere in this script): the two 'active' properties and the
# slider 'value' all trigger recomputation on the client side.
cbox.js_on_change('active', cb)
slider.js_on_change('value', cb_slider)
rb_group.js_on_change('active', cb_ff)
rb_group2.js_on_change('active', cb_struct)
div = Div(text="""<b>Figure: Comparison of H-SASA profiles calculated with different parameters with experimental cleavage frequencies. Top strand (TS) of <i>S. cerevisiae</i> centromeric nucleosome with 601TA DNA sequence.</b>
<br><br>
This interactive plot allows to explore dependence of H-SASA profiles on different paramters. Correlation coefficient is interactively displayed in the bottom left corner.
Use contols above to choose between following paramters:<br><br>
<b>Structure used for calculations</b>: X-H nuclear - orginal X-ray derived structure, hydrogen atoms added via REDUCE program with nuclear distances for X-H bond length;
X-H electron - original X-ray derived structure, hydrogen atoms added via REDUCE program with electron cloud distances for X-H bond length;
MD average - average profile for 50 frames spaced 1 ns apart from molecular dynamics simulations with CHARMM36 force field with NAMD.
<br><br>
<b>Radii of atoms used for SASA calculations</b>: FreeSASA Default - default radii in FreeSASA program; CHARMM36-rmin - van der Waals radii (Rmin paramter) of atoms as defined by CHARMM36 force field;
AMBER10-rmin - van der Waals radii (Rmin paramter) of atoms as defined by AMBER10 force field.
<br><br>
<b>Contributions of deoxyribose hydrogen atoms</b>: SASA of deoxyribose hydrogen atoms selected will be included in calculation of H-SASA profile.
<br><br>
Exemple #23
0
def _read_js_function(file_name):
    """Return the text of ``JS_Functions/<file_name>`` under the current working directory.

    Opened in a ``with`` block so the handle is closed right away (inline
    ``open(...).read()`` left closing the file to the garbage collector).
    """
    path = os.path.join(os.getcwd(), 'JS_Functions', file_name)
    with open(path, encoding='utf-8') as js_file:
        return js_file.read()


# NOTE(review): name='gau_sigma' looks copied from the Gaussian-smoothing slider — confirm.
mass_finder_mass = Slider(value=100, start=0.0, end=1000.0, step=10.0, title='Mass of Complex (kDa)',name='gau_sigma', width=250, height=30)

# Exact mass entry, initialised from the slider value converted kDa -> Da.
mass_finder_exact_mass_text = Div(text= "Enter exact Mass (Da)", width= 150, height=30 )
mass_finder_exact_mass_sele = TextInput(value=str(mass_finder_mass.value*1000), disabled=False, width=100, height=30)

mass_finder_line_text = Div(text= "Show mz prediction", width= 150, height=30 )
# NOTE(review): the widget-level `callback=` keyword is removed in Bokeh 2.x;
# verify the installed Bokeh version still accepts it.
mass_finder_line_sele = Toggle(label='off', active=False, width=100, height=30, callback=toggle_cb)

mass_finder_cb =CustomJS(args=dict(mass_finder_line_sele=mass_finder_line_sele, raw_mz=raw_mz, mass_finder_data=mass_finder_data, mass_finder_exact_mass_sele=mass_finder_exact_mass_sele, mass_finder_mass=mass_finder_mass, mass_finder_range_slider=mass_finder_range_slider, mfl=mfl), code=_read_js_function("mass_finder_cb.js"))
mass_finder_exact_cb =CustomJS(args=dict(mass_finder_line_sele=mass_finder_line_sele, mass_finder_exact_mass_sele=mass_finder_exact_mass_sele, mass_finder_mass=mass_finder_mass), code=_read_js_function("mass_finder_exact_cb.js"))
mass_finder_exact_mass_sele.js_on_change('value', mass_finder_exact_cb)

# Panel is hidden until the `mass_finder` toggle is activated; the prediction
# lines (mfl) follow the "Show mz prediction" toggle.
mass_finder_column=Column(mass_finder_header,mass_finder_mass, mass_finder_range_slider, Row(mass_finder_exact_mass_text,mass_finder_exact_mass_sele), Row(mass_finder_line_text, mass_finder_line_sele), visible=False)
mass_finder.js_link('active', mass_finder_column, 'visible')
mass_finder_line_sele.js_link('active', mfl, 'visible')
mass_finder_mass.js_on_change('value', mass_finder_cb)
mass_finder_line_sele.js_on_change('active', mass_finder_cb)
mass_finder_range_slider.js_on_change('value',mass_finder_cb)
### DATA PROCESSING ###

cropping = Div(text= " Range mz:", width= 150, height=30 )
gau_name = Div(text= " Gaussian Smoothing:", width= 150, height=30 )
n_smooth_name = Div(text= " Repeats of Smoothing:", width= 150, height=30 )
int_name = Div(text= " Intensity Threshold (%)", width= 150, height=30 )
# Option strings are kept verbatim: other callbacks may compare against them.
sub_name = Select(options=['Substract Minimum', 'Substract Line', 'Substract Curved'], name='sub_mode', value='Substract Minimum', width= 150, height=30 )
dt_name  = Div(text= " <h2>Data Processing</h2>", height=45 )
Exemple #24
0
        'slider': slider,
        'mvmt_name_mapper': mvmt_name_mapper
    },
             code="""            
            const select_value = mvmt_name_mapper[cb_obj.value]
            p1_cds.data = cds_pr_cal[select_value][slider.value].data
            p1_cds.change.emit()
        """))

# Whenever the slider moves, point the plotted source at the entry selected by
# the movement name (from the companion Select widget, mapped through
# mvmt_name_mapper) and the slider's current value.
# NOTE(review): ``p1_cds`` is the same object as cds_pr_cal['Y'][1], so the JS
# assignment to ``p1_cds.data`` appears to replace that entry's own data in the
# browser; going back to ('Y', 1) later may not show its original data.
# Consider a dedicated display source -- confirm against the renderer.
slider.js_on_change('value', CustomJS(
    args=dict(
        p1_cds=cds_pr_cal['Y'][1],
        cds_pr_cal=cds_pr_cal,
        select=select,
        mvmt_name_mapper=mvmt_name_mapper,
    ),
    code="""
            const select_value = mvmt_name_mapper[select.value]
            p1_cds.data = cds_pr_cal[select_value][cb_obj.value].data
            p1_cds.change.emit()
        """,
))

# Horizontal slider for the GitHub-style calendar plot: pans the calendar's
# x-range so that it starts at the slider value while keeping the width of the
# visible window unchanged.
date_slider_callback = CustomJS(
    code="""
    var a = cb_obj.value;
    var delta = p.x_range.end - p.x_range.start
    p.x_range.start = a;
    p.x_range.end = a+delta;
""",
    args={'p': pcal},
)
Exemple #25
0
def get_x_eval_selectors_list(result, active_list, x_eval):
    """Build one interactive selector widget per plotted parameter.

    Numerical dimensions get a ``Slider``; categorical dimensions get a
    ``Select`` dropdown. Every widget carries a JavaScript callback that moves
    the red markers held in the module-level ``source`` object to the widget's
    current value: the diagonal plot's line (``source.reds[n][n].location``),
    the markers below it in column ``n`` (x coordinate) and the markers to its
    left in row ``n`` (y coordinate).

    Args:
        result: optimisation result; ``result.space.bounds`` and
            ``result.space.dimensions`` are read.
        active_list: indices of the dimensions that are shown in the GUI.
        x_eval: current evaluation point, indexed by dimension; used as the
            initial value of each selector.

    Returns:
        List of ``Select``/``Slider`` widgets in ``active_list`` order.

    Side effects:
        Stores ``x_eval[i]`` in the global ``x_eval_selectors_values[i]`` for
        every ``i`` in ``active_list``.
    """
    global x_eval_selectors_values

    bounds = result.space.bounds  # Defines the value range each selector allows
    x_eval_selectors = []
    # ``n`` is the position among the *plotted* parameters, not the dimension
    # index: if only parameters 3 and 5 are plotted, their selectors still use
    # n = 0 and n = 1.
    for n, i in enumerate(active_list):
        dimension = result.space.dimensions[i]
        if isinstance(dimension, Categorical):
            cats = list(dimension.categories)
            # Dropdown titled after the parameter number, preset to x_eval[i].
            select = Select(title='X' + str(i),
                            value=x_eval[i],
                            options=cats,
                            width=200,
                            height=15)
            # The category is converted to its index and the markers are placed
            # at index + 0.5. Loop counters are declared with ``var`` so they
            # do not leak onto ``window`` as implicit globals (which would also
            # throw a ReferenceError under strict mode).
            select.js_on_change(
                'value',
                CustomJS(args=dict(source=source, n=n, cats=cats),
                         code="""
                // Convert categorical to index
                var ind = cats.indexOf(cb_obj.value);
                // Change red line in diagonal plots
                source['reds'][n][n]['location'] = ind + 0.5;
                // Change red markers in all contour plots
                // First we change the plots in a vertical direction
                for (var i = n+1; i < source.reds.length; i++) {
                    source.reds[i][n].data.x = [ind + 0.5] ;
                    source.reds[i][n].change.emit()
                }
                // Then in a horizontal direction
                for (var j = 0; j < n; j++) {
                    source.reds[n][j].data.y = [ind + 0.5] ;
                    source.reds[n][j].change.emit();
                }
                """))
            x_eval_selectors.append(select)
        else:
            # Numerical: slider over the dimension's bounds, with a step size
            # chosen from the size of that range.
            start = bounds[i][0]
            end = bounds[i][1]
            step = get_step_size(start, end)
            slider = Slider(start=start,
                            end=end,
                            value=x_eval[i],
                            step=step,
                            title='X' + str(i),
                            width=200,
                            height=30)
            # Called every time the user changes the slider value.
            slider.js_on_change(
                'value',
                CustomJS(args=dict(source=source, n=n),
                         code="""
                source.reds[n][n].location = cb_obj.value;
                source.reds[n][n].change.emit()
                for (var i = n+1; i < source.reds.length; i++) {
                    source.reds[i][n].data.x = [cb_obj.value] ;
                    source.reds[i][n].change.emit();
                }
                for (var j = 0; j < n; j++) {
                    source.reds[n][j].data.y = [cb_obj.value] ;
                    source.reds[n][j].change.emit();
                }
                """))
            x_eval_selectors.append(slider)
        # Keep the global record of selector values in sync with the widgets.
        x_eval_selectors_values[i] = x_eval[i]
    return x_eval_selectors
Exemple #26
0
def plotspectra(spectra,
                zcatalog=None,
                model=None,
                notebook=False,
                title=None):
    '''
    Display an interactive Bokeh viewer for a set of spectra.

    Args:
        spectra: a spectra object (uses ``bands``, ``wave``, ``flux``,
            ``ivar``, ``mask``, ``fibermap``, ``meta``, ``num_spectra()`` and
            ``target_ids()``), or a list of ``desispec.frame.Frame`` objects,
            which is converted with ``frames2spectra``.
        zcatalog: optional table with ``TARGETID``, ``Z``, ``ZERR``,
            ``SPECTYPE`` and ``ZWARN`` columns. It is reordered to match the
            target order of ``spectra``.
        model: optional ``(mwave, mflux)`` tuple. Row ``i`` of ``mflux`` is a
            model flux on ``mwave``. When ``zcatalog`` is given, the rows are
            assumed to follow the zcatalog row order and are reordered with it.
        notebook: if True, call ``bk.output_notebook()`` first.
        title: plot title; for frame input it defaults to
            "Night ... ExpID ... Spectrograph ...".

    Masked pixels (``ivar == 0`` or ``mask != 0``) are plotted as gaps. They
    are set to NaN on a copy of the flux, so the caller's ``spectra`` is not
    modified. The layout is displayed with ``bk.show``; nothing is returned.
    '''

    if notebook:
        bk.output_notebook()

    #- If inputs are frames, convert to a spectra object
    if isinstance(spectra, list) and isinstance(spectra[0],
                                                desispec.frame.Frame):
        spectra = frames2spectra(spectra)
        frame_input = True
    else:
        frame_input = False

    if frame_input and title is None:
        meta = spectra.meta
        title = 'Night {} ExpID {} Spectrograph {}'.format(
            meta['NIGHT'],
            meta['EXPID'],
            meta['CAMERA'][1],
        )

    #- Gather spectra into ColumnDataSource objects for Bokeh
    nspec = spectra.num_spectra()
    cds_spectra = list()

    for band in spectra.bands:
        #- Set masked bins to NaN so that Bokeh won't plot them.  Work on a
        #- copy so that the caller's flux array is left untouched.
        bad = (spectra.ivar[band] == 0.0) | (spectra.mask[band] != 0)
        flux = spectra.flux[band].copy()
        flux[bad] = np.nan

        cdsdata = dict(
            origwave=spectra.wave[band].copy(),
            plotwave=spectra.wave[band].copy(),
        )

        for i in range(nspec):
            cdsdata['origflux' + str(i)] = flux[i]

        #- plotflux is rewritten in place by the smoothing callback, so give it
        #- its own buffer rather than aliasing origflux0
        cdsdata['plotflux'] = cdsdata['origflux0'].copy()

        cds_spectra.append(bk.ColumnDataSource(cdsdata, name=band))

    #- Reorder zcatalog to match input targets
    #- TODO: allow more than one zcatalog entry with different ZNUM per targetid
    targetids = spectra.target_ids()
    if zcatalog is not None:
        ii = np.argsort(np.argsort(targetids))
        jj = np.argsort(zcatalog['TARGETID'])
        kk = jj[ii]
        zcatalog = zcatalog[kk]

        #- That sequence of argsorts may feel like magic,
        #- so make sure we got it right
        assert np.all(zcatalog['TARGETID'] == targetids)
        assert np.all(zcatalog['TARGETID'] == spectra.fibermap['TARGETID'])

        #- Also need to re-order input model fluxes
        if model is not None:
            mwave, mflux = model
            model = mwave, mflux[kk]

    #- Gather models into ColumnDataSource objects, row matched to spectra
    if model is not None:
        mwave, mflux = model
        cds_model_data = dict(
            origwave=mwave.copy(),
            plotwave=mwave.copy(),
        )

        for i in range(nspec):
            cds_model_data['origflux' + str(i)] = mflux[i]

        #- Separate buffer: plotflux is rewritten in place by the callbacks
        cds_model_data['plotflux'] = np.copy(cds_model_data['origflux0'])
        cds_model = bk.ColumnDataSource(cds_model_data)
    else:
        cds_model = None

    #- Subset of zcatalog and fibermap columns into ColumnDataSource
    target_info = list()
    for i, row in enumerate(spectra.fibermap):
        target_bit_names = ' '.join(desi_mask.names(row['DESI_TARGET']))
        txt = 'Target {}: {}'.format(row['TARGETID'], target_bit_names)
        if zcatalog is not None:
            txt += '<BR/>{} z={:.4f} ± {:.4f}  ZWARN={}'.format(
                zcatalog['SPECTYPE'][i],
                zcatalog['Z'][i],
                zcatalog['ZERR'][i],
                zcatalog['ZWARN'][i],
            )
        target_info.append(txt)

    cds_targetinfo = bk.ColumnDataSource(dict(target_info=target_info),
                                         name='targetinfo')
    if zcatalog is not None:
        cds_targetinfo.add(zcatalog['Z'], name='z')

    plot_width = 800
    plot_height = 400
    # tools = 'pan,box_zoom,wheel_zoom,undo,redo,reset,save'
    tools = 'pan,box_zoom,wheel_zoom,reset,save'
    fig = bk.figure(height=plot_height,
                    width=plot_width,
                    title=title,
                    tools=tools,
                    toolbar_location='above',
                    y_range=(-10, 20))
    fig.toolbar.active_drag = fig.tools[1]  #- box zoom
    fig.toolbar.active_scroll = fig.tools[2]  #- wheel zoom
    fig.xaxis.axis_label = 'Wavelength [Å]'
    fig.yaxis.axis_label = 'Flux'
    fig.xaxis.axis_label_text_font_style = 'normal'
    fig.yaxis.axis_label_text_font_style = 'normal'
    colors = dict(b='#1f77b4', r='#d62728', z='maroon')

    data_lines = list()
    for spec in cds_spectra:
        lx = fig.line('plotwave',
                      'plotflux',
                      source=spec,
                      line_color=colors[spec.name])
        data_lines.append(lx)

    if cds_model is not None:
        model_lines = list()
        lx = fig.line('plotwave',
                      'plotflux',
                      source=cds_model,
                      line_color='black')
        model_lines.append(lx)

        legend = Legend(items=[
            ("data",
             data_lines[-1::-1]),  #- reversed to get blue as lengend entry
            ("model", model_lines),
        ])
    else:
        legend = Legend(items=[
            ("data",
             data_lines[-1::-1]),  #- reversed to get blue as lengend entry
        ])

    fig.add_layout(legend, 'center')
    fig.legend.click_policy = 'hide'  #- or 'mute'

    #- Zoom figure around mouse hover of main plot
    zoomfig = bk.figure(
        height=plot_height // 2,
        width=plot_height // 2,
        y_range=fig.y_range,
        x_range=(5000, 5100),
        # output_backend="webgl",
        toolbar_location=None,
        tools=[])

    for spec in cds_spectra:
        zoomfig.line('plotwave',
                     'plotflux',
                     source=spec,
                     line_color=colors[spec.name],
                     line_width=1,
                     line_alpha=1.0)

    if cds_model is not None:
        zoomfig.line('plotwave',
                     'plotflux',
                     source=cds_model,
                     line_color='black')

    #- Callback to update zoom window x-range
    zoom_callback = CustomJS(args=dict(zoomfig=zoomfig),
                             code="""
            zoomfig.x_range.start = cb_obj.x - 100;
            zoomfig.x_range.end = cb_obj.x + 100;
        """)

    fig.js_on_event(bokeh.events.MouseMove, zoom_callback)

    #-----
    #- Emission and absorption lines
    z = zcatalog['Z'][0] if (zcatalog is not None) else 0.0
    line_data, lines, line_labels = add_lines(fig, z=z)

    #-----
    #- Add widgets for controling plots
    #- The redshift is split into a coarse (0.01 step) and a fine part
    z1 = np.floor(z * 100) / 100
    dz = z - z1
    zslider = Slider(start=0.0, end=4.0, value=z1, step=0.01, title='Redshift')
    dzslider = Slider(start=0.0,
                      end=0.01,
                      value=dz,
                      step=0.0001,
                      title='+ Delta redshift')
    dzslider.format = "0[.]0000"

    #- Observer vs. Rest frame wavelengths
    waveframe_buttons = RadioButtonGroup(labels=["Obs", "Rest"], active=0)

    ifiberslider = Slider(start=0, end=nspec - 1, value=0, step=1)
    if frame_input:
        ifiberslider.title = 'Fiber'
    else:
        ifiberslider.title = 'Target'

    #- Move the line markers and rescale the wavelength arrays for the chosen
    #- redshift and frame (observer vs. rest)
    zslider_callback = CustomJS(
        args=dict(
            spectra=cds_spectra,
            model=cds_model,
            targetinfo=cds_targetinfo,
            ifiberslider=ifiberslider,
            zslider=zslider,
            dzslider=dzslider,
            waveframe_buttons=waveframe_buttons,
            line_data=line_data,
            lines=lines,
            line_labels=line_labels,
            fig=fig,
        ),
        #- TODO: reorder to reduce duplicated code
        code="""
        var z = zslider.value + dzslider.value
        var line_restwave = line_data.data['restwave']
        var ifiber = ifiberslider.value
        var zfit = 0.0
        if(targetinfo.data['z'] != undefined) {
            zfit = targetinfo.data['z'][ifiber]
        }

        // Observer Frame
        if(waveframe_buttons.active == 0) {
            var x = 0.0
            for(var i=0; i<line_restwave.length; i++) {
                x = line_restwave[i] * (1+z)
                lines[i].location = x
                line_labels[i].x = x
            }
            for(var i=0; i<spectra.length; i++) {
                var data = spectra[i].data
                var origwave = data['origwave']
                var plotwave = data['plotwave']
                for (var j=0; j<plotwave.length; j++) {
                    plotwave[j] = origwave[j]
                }
                spectra[i].change.emit()
            }

            // Update model wavelength array
            if(model) {
                var origwave = model.data['origwave']
                var plotwave = model.data['plotwave']
                for(var i=0; i<plotwave.length; i++) {
                    plotwave[i] = origwave[i] * (1+z) / (1+zfit)
                }
                model.change.emit()
            }

        // Rest Frame
        } else {
            for(i=0; i<line_restwave.length; i++) {
                lines[i].location = line_restwave[i]
                line_labels[i].x = line_restwave[i]
            }
            for (var i=0; i<spectra.length; i++) {
                var data = spectra[i].data
                var origwave = data['origwave']
                var plotwave = data['plotwave']
                for (var j=0; j<plotwave.length; j++) {
                    plotwave[j] = origwave[j] / (1+z)
                }
                spectra[i].change.emit()
            }

            // Update model wavelength array
            if(model) {
                var origwave = model.data['origwave']
                var plotwave = model.data['plotwave']
                for(var i=0; i<plotwave.length; i++) {
                    plotwave[i] = origwave[i] / (1+zfit)
                }
                model.change.emit()
            }
        }
        """)

    zslider.js_on_change('value', zslider_callback)
    dzslider.js_on_change('value', zslider_callback)
    waveframe_buttons.js_on_click(zslider_callback)

    #- Keep the visible x-range consistent when toggling the wavelength frame
    plotrange_callback = CustomJS(args=dict(
        zslider=zslider,
        dzslider=dzslider,
        waveframe_buttons=waveframe_buttons,
        fig=fig,
    ),
                                  code="""
        var z = zslider.value + dzslider.value
        // Observer Frame
        if(waveframe_buttons.active == 0) {
            fig.x_range.start = fig.x_range.start * (1+z)
            fig.x_range.end = fig.x_range.end * (1+z)
        } else {
            fig.x_range.start = fig.x_range.start / (1+z)
            fig.x_range.end = fig.x_range.end / (1+z)
        }
        """)
    waveframe_buttons.js_on_click(plotrange_callback)

    smootherslider = Slider(start=0,
                            end=31,
                            value=0,
                            step=1.0,
                            title='Gaussian Sigma Smooth')
    target_info_div = Div(text=target_info[0])

    #-----
    #- Toggle lines
    lines_button_group = CheckboxButtonGroup(labels=["Emission", "Absorption"],
                                             active=[])

    lines_callback = CustomJS(args=dict(line_data=line_data,
                                        lines=lines,
                                        line_labels=line_labels),
                              code="""
        var show_emission = false
        var show_absorption = false
        if (cb_obj.active.indexOf(0) >= 0) {  // index 0=Emission in active list
            show_emission = true
        }
        if (cb_obj.active.indexOf(1) >= 0) {  // index 1=Absorption in active list
            show_absorption = true
        }

        for(var i=0; i<lines.length; i++) {
            if(line_data.data['emission'][i]) {
                lines[i].visible = show_emission
                line_labels[i].visible = show_emission
            } else {
                lines[i].visible = show_absorption
                line_labels[i].visible = show_absorption
            }
        }
        """)
    lines_button_group.js_on_click(lines_callback)
    # lines_button_group.js_on_change('value', lines_callback)

    #-----
    #- Switch the displayed target, smooth its flux and rescale the y-range.
    #- get_y_minmax sorts with an explicit numeric comparator (the default
    #- Array sort compares as strings), drops only NaN (0.0 is a valid flux),
    #- and returns [0, 0] for an all-NaN array instead of undefined values.
    #- NOTE(review): the kernel is exp(-i^2 / (2*nsmooth)), i.e. variance =
    #- nsmooth, while the slider is titled 'Gaussian Sigma Smooth' -- confirm
    #- whether 2*nsmooth**2 was intended.
    update_plot = CustomJS(args=dict(
        spectra=cds_spectra,
        model=cds_model,
        targetinfo=cds_targetinfo,
        target_info_div=target_info_div,
        ifiberslider=ifiberslider,
        smootherslider=smootherslider,
        zslider=zslider,
        dzslider=dzslider,
        lines_button_group=lines_button_group,
        fig=fig,
    ),
                           code="""
        var ifiber = ifiberslider.value
        var nsmooth = smootherslider.value
        target_info_div.text = targetinfo.data['target_info'][ifiber]

        if(targetinfo.data['z'] != undefined) {
            var z = targetinfo.data['z'][ifiber]
            var z1 = Math.floor(z*100) / 100
            zslider.value = z1
            dzslider.value = (z - z1)
        }

        function get_y_minmax(pmin, pmax, data) {
            // copy before sorting to not impact original, and filter out NaN
            // (only NaN: zero-valued fluxes are real data)
            var dx = data.slice().filter(function(v) { return v == v })
            if(dx.length == 0) {
                return [0.0, 0.0]
            }
            // numeric comparator: the default sort compares values as strings
            dx.sort(function(a, b) { return a - b })
            var imin = Math.floor(pmin * dx.length)
            var imax = Math.floor(pmax * dx.length)
            return [dx[imin], dx[imax]]
        }

        // Smoothing kernel
        var kernel = [];
        for(var i=-2*nsmooth; i<=2*nsmooth; i++) {
            kernel.push(Math.exp(-(i**2)/(2*nsmooth)))
        }
        var kernel_offset = Math.floor(kernel.length/2)

        // Smooth plot and recalculate ymin/ymax
        // TODO: add smoother function to reduce duplicated code
        var ymin = 0.0
        var ymax = 0.0
        for (var i=0; i<spectra.length; i++) {
            var data = spectra[i].data
            var plotflux = data['plotflux']
            var origflux = data['origflux'+ifiber]
            for (var j=0; j<plotflux.length; j++) {
                if(nsmooth == 0) {
                    plotflux[j] = origflux[j]
                } else {
                    plotflux[j] = 0.0
                    var weight = 0.0
                    // TODO: speed could be improved by moving `if` out of loop
                    for (var k=0; k<kernel.length; k++) {
                        var m = j+k-kernel_offset
                        if((m >= 0) && (m < plotflux.length)) {
                            var fx = origflux[m]
                            if(fx == fx) {
                                plotflux[j] = plotflux[j] + fx * kernel[k]
                                weight += kernel[k]
                            }
                        }
                    }
                    plotflux[j] = plotflux[j] / weight
                }
            }
            spectra[i].change.emit()

            var tmp = get_y_minmax(0.01, 0.99, plotflux)
            ymin = Math.min(ymin, tmp[0])
            ymax = Math.max(ymax, tmp[1])
        }

        // update model
        if(model) {
            var plotflux = model.data['plotflux']
            var origflux = model.data['origflux'+ifiber]
            for (var j=0; j<plotflux.length; j++) {
                if(nsmooth == 0) {
                    plotflux[j] = origflux[j]
                } else {
                    plotflux[j] = 0.0
                    var weight = 0.0
                    // TODO: speed could be improved by moving `if` out of loop
                    for (var k=0; k<kernel.length; k++) {
                        var m = j+k-kernel_offset
                        if((m >= 0) && (m < plotflux.length)) {
                            var fx = origflux[m]
                            if(fx == fx) {
                                plotflux[j] = plotflux[j] + fx * kernel[k]
                                weight += kernel[k]
                            }
                        }
                    }
                    plotflux[j] = plotflux[j] / weight
                }
            }
            model.change.emit()
        }

        // update y_range
        if(ymin<0) {
            fig.y_range.start = ymin * 1.4
        } else {
            fig.y_range.start = ymin * 0.6
        }
        fig.y_range.end = ymax * 1.4
    """)
    smootherslider.js_on_change('value', update_plot)
    ifiberslider.js_on_change('value', update_plot)

    #-----
    #- Add navigation buttons
    navigation_button_width = 30
    prev_button = Button(label="<", width=navigation_button_width)
    next_button = Button(label=">", width=navigation_button_width)

    prev_callback = CustomJS(args=dict(ifiberslider=ifiberslider),
                             code="""
        if(ifiberslider.value>0) {
            ifiberslider.value--
        }
        """)
    #- The last valid index is nspec-1 (the slider's ``end``); never step past it
    next_callback = CustomJS(args=dict(ifiberslider=ifiberslider, nspec=nspec),
                             code="""
        if(ifiberslider.value<nspec-1) {
            ifiberslider.value++
        }
        """)

    prev_button.js_on_event('button_click', prev_callback)
    next_button.js_on_event('button_click', next_callback)

    #-----
    slider_width = plot_width - 2 * navigation_button_width
    navigator = bk.Row(
        widgetbox(prev_button, width=navigation_button_width),
        widgetbox(next_button, width=navigation_button_width + 20),
        widgetbox(ifiberslider, width=slider_width - 20))
    bk.show(
        bk.Column(
            bk.Row(fig, zoomfig),
            widgetbox(target_info_div, width=plot_width),
            navigator,
            widgetbox(smootherslider, width=plot_width // 2),
            bk.Row(
                widgetbox(waveframe_buttons, width=120),
                widgetbox(zslider, width=plot_width // 2 - 60),
                widgetbox(dzslider, width=plot_width // 2 - 60),
            ),
            widgetbox(lines_button_group),
        ))
Exemple #27
0
p_crd.add_tools(draw_tool)
p_crd.toolbar.active_tap = draw_tool


# Persistency threshold slider: moves the bottom edge of ``box`` to the slider
# value and rebuilds ``source_sx_persistency`` from the rows of ``source_sx``
# whose ``sx_value`` is strictly greater than that value. The per-column copy
# iterates over one list of column names instead of six copy-pasted pushes;
# the leftover debug ``console.log`` has been removed.
slider_sx = Slider(start=0, end=5, value=5, step=.1, title="Persistency", sizing_mode='stretch_width')
slider_sx.js_on_change('value',
                      CustomJS(args=dict(box=box, source_sx=source_sx, source_sx_persistency=source_sx_persistency),
    code="""
    box.bottom = cb_obj.value;
    const d1 = source_sx.data;
    const columns = ['index', 'sx_value',
                     'start_datetime_str', 'end_datetime_str',
                     'start_datetime_dt', 'end_datetime_dt'];
    const d2 = {};
    for (const col of columns) {
        d2[col] = [];
    }
    for (let i = 0; i < d1['index'].length; i++) {
        if (d1['sx_value'][i] > cb_obj.value) {
            for (const col of columns) {
                d2[col].push(d1[col][i]);
            }
        }
    }
    source_sx_persistency.data = d2;
    source_sx_persistency.change.emit();
    """))



toggle_sx = Toggle(label="Persistence", button_type="success", active=True)
toggle1 = Toggle(label="Word", button_type="success", active=True)
def plot_3DModel_bokeh(filename, map_data_all_slices, map_depth_all_slices, \
                       color_range_all_slices, profile_data_all, boundary_data, \
                       style_parameter):
    '''
    Plot shear velocity maps and velocity profiles using bokeh

    Builds a standalone HTML page with two linked columns:

    * Left: a depth slider, a map view showing an RGBA image of the
      selected depth slice, boundaries, clickable grid points and a
      colorbar.
    * Right: a profile slider, a step-style Vs profile of the selected
      grid point, and two buttons that download the selected model as
      text ("simple text" and "model96" layouts).

    Input:
        filename is the filename of the resulting html file
        map_data_all_slices contains the velocity model parameters saved for map view plots
        map_depth_all_slices is a list of depths (km), one per map slice
        color_range_all_slices is a list of (min, max) color ranges, one per map slice
        profile_data_all contains the velocity model parameters saved for profile plots;
            each entry is read for the keys 'lat', 'lon', 'vs', 'vp', 'rho' and 'top'
        boundary_data is a dict of boundaries with the keys 'country', 'marine',
            'shoreline' and 'state', each holding 'longitude' and 'latitude'
        style_parameter contains plotting parameters (sizes, titles, labels,
            default indices, tools, ...)

    Output:
        None. The layout is written to `filename` via bokeh's output_file/save.

    Notes:
        Relies on the module-level `palette` and `val_to_rgb`.
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    # ------------------------------
    # colorbar bins for every depth slice: ncolor equal-width bins spanning
    # that slice's [color_min, color_max] range
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    map_view_ndepth = style_parameter['map_view_ndepth']
    ncolor = len(palette)
    colorbar_top = [0.1 for i in range(ncolor)]
    colorbar_bottom = [0 for i in range(ncolor)]
    map_data_all_slices_depth = []
    for idepth in range(map_view_ndepth):
        color_min = color_range_all_slices[idepth][0]
        color_max = color_range_all_slices[idepth][1]
        color_step = (color_max - color_min) * 1. / ncolor
        colorbar_left = np.linspace(color_min, color_max - color_step, ncolor)
        colorbar_right = np.linspace(color_min + color_step, color_max, ncolor)
        colorbar_data_all_left.append(colorbar_left)
        colorbar_data_all_right.append(colorbar_right)
        map_depth = map_depth_all_slices[idepth]
        map_data_all_slices_depth.append(
            'Depth: {0:8.1f} km'.format(map_depth))
    # reversed palette: used both for the colorbar and for val_to_rgb below
    palette_r = palette[::-1]
    # data for the colorbar
    colorbar_data_one_slice = {}
    colorbar_data_one_slice['colorbar_left'] = colorbar_data_all_left[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice['colorbar_right'] = colorbar_data_all_right[
        style_parameter['map_view_default_index']]
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(colorbar_top=colorbar_top,colorbar_bottom=colorbar_bottom,\
                                                               colorbar_left=colorbar_data_one_slice['colorbar_left'],\
                                                               colorbar_right=colorbar_data_one_slice['colorbar_right'],\
                                                               palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(colorbar_data_all_left=colorbar_data_all_left,\
                                                                colorbar_data_all_right=colorbar_data_all_right))
    # ------------------------------
    # depth label shown on the map
    map_view_label_lon = style_parameter['map_view_depth_label_lon']
    map_view_label_lat = style_parameter['map_view_depth_label_lat']
    map_data_one_slice_depth = map_data_all_slices_depth[
        style_parameter['map_view_default_index']]
    map_data_one_slice_depth_bokeh = ColumnDataSource(data=dict(lat=[map_view_label_lat], lon=[map_view_label_lon],
                                                           map_depth=[map_data_one_slice_depth],
                                                           left=[style_parameter['profile_plot_xmin']], \
                                                           right=[style_parameter['profile_plot_xmax']]))

    # ------------------------------
    map_view_default_index = style_parameter['map_view_default_index']
    # Convert every slice to packed RGBA for image_rgba: each pixel's color is
    # viewed as one uint32 (assumes val_to_rgb returns an (ny, nx, 4) uint8
    # array — TODO confirm).
    map_color_all_slices = []
    for i in range(len(map_data_all_slices)):
        vmin, vmax = color_range_all_slices[i]
        map_color = val_to_rgb(map_data_all_slices[i], palette_r, vmin, vmax)
        map_color_2d = map_color.view('uint32').reshape(map_color.shape[:2])
        map_color_all_slices.append(map_color_2d)
    map_color_one_slice = map_color_all_slices[map_view_default_index]
    #
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(x=[style_parameter['map_view_image_lon_min']],\
                   y=[style_parameter['map_view_image_lat_min']],dw=[style_parameter['nlon']*style_parameter['dlon']],\
                   dh=[style_parameter['nlat']*style_parameter['dlat']],map_data_one_slice=[map_color_one_slice]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(map_data_all_slices=map_color_all_slices,\
                                                           map_data_all_slices_depth=map_data_all_slices_depth))
    # ------------------------------
    # clickable grid points (one per profile) and the highlighted one
    nprofile = len(profile_data_all)
    grid_lat_list = []
    grid_lon_list = []
    width_list = []
    height_list = []
    for iprofile in range(nprofile):
        aprofile = profile_data_all[iprofile]
        grid_lat_list.append(aprofile['lat'])
        grid_lon_list.append(aprofile['lon'])
        width_list.append(style_parameter['map_view_grid_width'])
        height_list.append(style_parameter['map_view_grid_height'])
    grid_data_bokeh = ColumnDataSource(data=dict(lon=grid_lon_list,lat=grid_lat_list,\
                                            width=width_list, height=height_list))
    profile_default_index = style_parameter['profile_default_index']
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(lat=[grid_lat_list[profile_default_index]], \
                                                           lon=[grid_lon_list[profile_default_index]], \
                                                           width=[style_parameter['map_view_grid_width']],\
                                                           height=[style_parameter['map_view_grid_height']],\
                                                           index=[profile_default_index]))
    # ------------------------------
    # Step-style profiles: each layer contributes two points (vs at its top
    # depth and vs at the next layer's top depth), so 'top' must have at
    # least profile_ndepth + 1 entries.
    profile_vs_all = []
    profile_depth_all = []
    profile_ndepth = style_parameter['profile_ndepth']
    profile_lat_label_list = []
    profile_lon_label_list = []
    for iprofile in range(nprofile):
        aprofile = profile_data_all[iprofile]
        vs_raw = aprofile['vs']
        top_raw = aprofile['top']
        profile_lat_label_list.append('Lat: {0:12.1f}'.format(aprofile['lat']))
        profile_lon_label_list.append('Lon: {0:12.1f}'.format(aprofile['lon']))
        vs_plot = []
        depth_plot = []
        for idepth in range(profile_ndepth):
            vs_plot.append(vs_raw[idepth])
            depth_plot.append(top_raw[idepth])
            vs_plot.append(vs_raw[idepth])
            depth_plot.append(top_raw[idepth + 1])
        profile_vs_all.append(vs_plot)
        profile_depth_all.append(depth_plot)
    profile_data_all_bokeh = ColumnDataSource(data=dict(profile_vs_all=profile_vs_all, \
                                                        profile_depth_all=profile_depth_all))
    selected_profile_data_bokeh = ColumnDataSource(data=dict(vs=profile_vs_all[profile_default_index],\
                                                             depth=profile_depth_all[profile_default_index]))
    selected_profile_lat_label_bokeh = ColumnDataSource(data=\
                                dict(x=[style_parameter['profile_lat_label_x']], y=[style_parameter['profile_lat_label_y']],\
                                    lat_label=[profile_lat_label_list[profile_default_index]]))
    selected_profile_lon_label_bokeh = ColumnDataSource(data=\
                                dict(x=[style_parameter['profile_lon_label_x']], y=[style_parameter['profile_lon_label_y']],\
                                    lon_label=[profile_lon_label_list[profile_default_index]]))
    all_profile_lat_label_bokeh = ColumnDataSource(data=dict(
        profile_lat_label_list=profile_lat_label_list))
    all_profile_lon_label_bokeh = ColumnDataSource(data=dict(
        profile_lon_label_list=profile_lon_label_list))
    # ------------------------------
    # model data exported by the download buttons (first button_ndepth layers)
    button_ndepth = style_parameter['button_ndepth']
    button_data_all_vs = []
    button_data_all_vp = []
    button_data_all_rho = []
    button_data_all_top = []
    for iprofile in range(nprofile):
        aprofile = profile_data_all[iprofile]
        button_data_all_vs.append(aprofile['vs'][:button_ndepth])
        button_data_all_vp.append(aprofile['vp'][:button_ndepth])
        button_data_all_rho.append(aprofile['rho'][:button_ndepth])
        button_data_all_top.append(aprofile['top'][:button_ndepth])
    button_data_all_bokeh = ColumnDataSource(data=dict(button_data_all_vs=button_data_all_vs,\
                                                       button_data_all_vp=button_data_all_vp,\
                                                       button_data_all_rho=button_data_all_rho,\
                                                       button_data_all_top=button_data_all_top))
    # ==============================
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'], plot_height=style_parameter['map_view_plot_height'], \
                      tools=style_parameter['map_view_tools'], title=style_parameter['map_view_title'], \
                      y_range=[style_parameter['map_view_figure_lat_min'], style_parameter['map_view_figure_lat_max']],\
                      x_range=[style_parameter['map_view_figure_lon_min'], style_parameter['map_view_figure_lon_max']])
    #
    map_view.image_rgba('map_data_one_slice',x='x',\
                   y='y',dw='dw',dh='dh',
                   source=map_data_one_slice_bokeh, level='image')

    # depth slider: swap the displayed slice, its colorbar bins and depth label
    depth_slider_callback = CustomJS(args=dict(map_data_one_slice_bokeh=map_data_one_slice_bokeh,\
                                               map_data_all_slices_bokeh=map_data_all_slices_bokeh,\
                                               colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh,\
                                               colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,\
                                               map_data_one_slice_depth_bokeh=map_data_one_slice_depth_bokeh), code="""

        var d_index = Math.round(cb_obj.value)
        
        var map_data_all_slices = map_data_all_slices_bokeh.data
        
        map_data_one_slice_bokeh.data['map_data_one_slice'] = [map_data_all_slices['map_data_all_slices'][d_index]]
        map_data_one_slice_bokeh.change.emit()
        
        var color_data_all_slices = colorbar_data_all_slices_bokeh.data
        colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][d_index]
        colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][d_index]
        colorbar_data_one_slice_bokeh.change.emit()
        
        map_data_one_slice_depth_bokeh.data['map_depth'] = [map_data_all_slices['map_data_all_slices_depth'][d_index]]
        map_data_one_slice_depth_bokeh.change.emit()
        
    """)
    depth_slider = Slider(start=0, end=style_parameter['map_view_ndepth']-1, \
                          value=map_view_default_index, step=1, \
                          width=style_parameter['map_view_plot_width'],\
                          title=style_parameter['depth_slider_title'], height=50)
    depth_slider.js_on_change('value', depth_slider_callback)
    depth_slider_callback.args["depth_index"] = depth_slider
    # ------------------------------
    # add boundaries to map view
    # country boundaries
    map_view.multi_line(boundary_data['country']['longitude'],\
                        boundary_data['country']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # marine boundaries
    map_view.multi_line(boundary_data['marine']['longitude'],\
                        boundary_data['marine']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # shoreline boundaries
    map_view.multi_line(boundary_data['shoreline']['longitude'],\
                        boundary_data['shoreline']['latitude'],color='black',\
                        line_width=2, level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # state boundaries
    map_view.multi_line(boundary_data['state']['longitude'],\
                        boundary_data['state']['latitude'],color='black',\
                        level='underlay',nonselection_line_alpha=1.0,\
                        nonselection_line_color='black')
    # ------------------------------
    # add depth label
    map_view.rect(style_parameter['map_view_depth_box_lon'], style_parameter['map_view_depth_box_lat'], \
                  width=style_parameter['map_view_depth_box_width'], height=style_parameter['map_view_depth_box_height'], \
                  width_units='screen',height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_depth', source=map_data_one_slice_depth_bokeh,\
                  text_font_size=style_parameter['annotating_text_font_size'],text_align='left',level='underlay')
    # ------------------------------
    # grid points (selection/nonselection styling kept identical so a click
    # does not visually dim the other points) and the highlighted point
    map_view.rect('lon', 'lat', width='width', \
                  width_units='screen', height='height', \
                  height_units='screen', line_color='gray', line_alpha=0.5, \
                  selection_line_color='gray', selection_line_alpha=0.5, selection_fill_color=None,\
                  nonselection_line_color='gray',nonselection_line_alpha=0.5, nonselection_fill_color=None,\
                  source=grid_data_bokeh, color=None, line_width=1, level='glyph')
    map_view.rect('lon', 'lat',width='width', \
                  width_units='screen', height='height', \
                  height_units='screen', line_color='#00ff00', line_alpha=1.0, \
                  source=selected_dot_on_map_bokeh, fill_color=None, line_width=3.,level='glyph')
    # ------------------------------
    # Map click: update the highlighted point, the profile and its labels.
    # cb_obj is the Selection model, so cb_obj.indices is an array. An empty
    # selection (tap on empty space) is ignored so the last valid profile
    # stays shown, and only the first selected index is used; indexing the
    # per-grid arrays with the whole array yields undefined and breaks the
    # plot and the download buttons.
    grid_data_js = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh, \
                                                  grid_data_bokeh=grid_data_bokeh,\
                                                  profile_data_all_bokeh=profile_data_all_bokeh,\
                                                  selected_profile_data_bokeh=selected_profile_data_bokeh,\
                                                  selected_profile_lat_label_bokeh=selected_profile_lat_label_bokeh,\
                                                  selected_profile_lon_label_bokeh=selected_profile_lon_label_bokeh, \
                                                  all_profile_lat_label_bokeh=all_profile_lat_label_bokeh, \
                                                  all_profile_lon_label_bokeh=all_profile_lon_label_bokeh, \
                                                 ), code="""
        
        var selected_indices = cb_obj.indices
        if (selected_indices.length == 0) {
            return
        }
        var inds = selected_indices[0]
        
        var grid_data = grid_data_bokeh.data
        selected_dot_on_map_bokeh.data['lat'] = [grid_data['lat'][inds]]
        selected_dot_on_map_bokeh.data['lon'] = [grid_data['lon'][inds]]
        selected_dot_on_map_bokeh.data['index'] = [inds]
        selected_dot_on_map_bokeh.change.emit()
        
        var profile_data_all = profile_data_all_bokeh.data
        selected_profile_data_bokeh.data['vs'] = profile_data_all['profile_vs_all'][inds]
        selected_profile_data_bokeh.data['depth'] = profile_data_all['profile_depth_all'][inds]
        selected_profile_data_bokeh.change.emit()
        
        var all_profile_lat_label = all_profile_lat_label_bokeh.data['profile_lat_label_list']
        var all_profile_lon_label = all_profile_lon_label_bokeh.data['profile_lon_label_list']
        selected_profile_lat_label_bokeh.data['lat_label'] = [all_profile_lat_label[inds]]
        selected_profile_lon_label_bokeh.data['lon_label'] = [all_profile_lon_label[inds]]
        selected_profile_lat_label_bokeh.change.emit()
        selected_profile_lon_label_bokeh.change.emit()
    """)
    grid_data_bokeh.selected.js_on_change('indices', grid_data_js)
    # ------------------------------
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # ==============================
    # plot colorbar
    colorbar_fig = Figure(tools=[], y_range=(0,0.1),plot_width=style_parameter['map_view_plot_width'], \
                      plot_height=style_parameter['colorbar_plot_height'],title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top',bottom='colorbar_bottom',left='colorbar_left',right='colorbar_right',\
                  color='palette_r',source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    # profile figure; y range goes from ymax to ymin so depth increases downward
    profile_xrange = Range1d(start=style_parameter['profile_plot_xmin'],
                             end=style_parameter['profile_plot_xmax'])
    profile_yrange = Range1d(start=style_parameter['profile_plot_ymax'],
                             end=style_parameter['profile_plot_ymin'])
    profile_fig = Figure(plot_width=style_parameter['profile_plot_width'], plot_height=style_parameter['profile_plot_height'],\
                         x_range=profile_xrange, y_range=profile_yrange, tools=style_parameter['profile_tools'],\
                         title=style_parameter['profile_title'])
    profile_fig.line('vs',
                     'depth',
                     source=selected_profile_data_bokeh,
                     line_width=2,
                     line_color='black')
    # ------------------------------
    # add lat, lon
    profile_fig.rect([style_parameter['profile_label_box_x']], [style_parameter['profile_label_box_y']],\
                     width=style_parameter['profile_label_box_width'], height=style_parameter['profile_label_box_height'],\
                     width_units='screen', height_units='screen', color='#FFFFFF', line_width=1., line_color='black',\
                     level='underlay')
    profile_fig.text('x',
                     'y',
                     'lat_label',
                     source=selected_profile_lat_label_bokeh)
    profile_fig.text('x',
                     'y',
                     'lon_label',
                     source=selected_profile_lon_label_bokeh)
    # ------------------------------
    # change style
    profile_fig.xaxis.axis_label = style_parameter['profile_xlabel']
    profile_fig.xaxis.axis_label_text_font_style = 'normal'
    profile_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    profile_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    profile_fig.yaxis.axis_label = style_parameter['profile_ylabel']
    profile_fig.yaxis.axis_label_text_font_style = 'normal'
    profile_fig.yaxis.axis_label_text_font_size = xlabel_fontsize
    profile_fig.yaxis.major_label_text_font_size = xlabel_fontsize
    profile_fig.xgrid.grid_line_dash = [4, 2]
    profile_fig.ygrid.grid_line_dash = [4, 2]
    profile_fig.title.text_font_size = style_parameter['title_font_size']
    profile_fig.title.align = 'center'
    profile_fig.title.text_font_style = 'normal'
    profile_fig.toolbar_location = 'above'
    profile_fig.toolbar_sticky = False
    profile_fig.toolbar.logo = None
    # ==============================
    # profile slider: same updates as a map click, driven by the slider index
    profile_slider_callback = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                                 grid_data_bokeh=grid_data_bokeh, \
                                                 profile_data_all_bokeh=profile_data_all_bokeh, \
                                                 selected_profile_data_bokeh=selected_profile_data_bokeh,\
                                                 selected_profile_lat_label_bokeh=selected_profile_lat_label_bokeh,\
                                                 selected_profile_lon_label_bokeh=selected_profile_lon_label_bokeh, \
                                                 all_profile_lat_label_bokeh=all_profile_lat_label_bokeh, \
                                                 all_profile_lon_label_bokeh=all_profile_lon_label_bokeh), code="""
        var p_index = Math.round(cb_obj.value)
        
        var grid_data = grid_data_bokeh.data
        selected_dot_on_map_bokeh.data['lat'] = [grid_data['lat'][p_index]]
        selected_dot_on_map_bokeh.data['lon'] = [grid_data['lon'][p_index]]
        selected_dot_on_map_bokeh.data['index'] = [p_index]
        selected_dot_on_map_bokeh.change.emit()
        
        var profile_data_all = profile_data_all_bokeh.data
        selected_profile_data_bokeh.data['vs'] = profile_data_all['profile_vs_all'][p_index]
        selected_profile_data_bokeh.data['depth'] = profile_data_all['profile_depth_all'][p_index]
        selected_profile_data_bokeh.change.emit()
        
        var all_profile_lat_label = all_profile_lat_label_bokeh.data['profile_lat_label_list']
        var all_profile_lon_label = all_profile_lon_label_bokeh.data['profile_lon_label_list']
        selected_profile_lat_label_bokeh.data['lat_label'] = [all_profile_lat_label[p_index]]
        selected_profile_lon_label_bokeh.data['lon_label'] = [all_profile_lon_label[p_index]]
        selected_profile_lat_label_bokeh.change.emit()
        selected_profile_lon_label_bokeh.change.emit()
        
    """)
    profile_slider = Slider(start=0, end=nprofile-1, value=style_parameter['profile_default_index'], \
                           step=1, title=style_parameter['profile_slider_title'], \
                           width=style_parameter['profile_plot_width'], height=50)
    profile_slider_callback.args['profile_index'] = profile_slider
    profile_slider.js_on_change('value', profile_slider_callback)
    # ==============================
    # "simple text" download: one line per layer (top, vs, vp, rho) until the
    # first falsy vp value, saved client-side as vel_model.txt
    simple_text_button_callback = CustomJS(args=dict(button_data_all_bokeh=button_data_all_bokeh,\
                                                    selected_dot_on_map_bokeh=selected_dot_on_map_bokeh), \
                                           code="""
        var index = selected_dot_on_map_bokeh.data['index']
        
        var button_data_vs = button_data_all_bokeh.data['button_data_all_vs'][index]
        var button_data_vp = button_data_all_bokeh.data['button_data_all_vp'][index]
        var button_data_rho = button_data_all_bokeh.data['button_data_all_rho'][index]
        var button_data_top = button_data_all_bokeh.data['button_data_all_top'][index]
        
        var csvContent = ""
        var i = 0
        var temp = csvContent
        temp += "# Layer Top (km)      Vs(km/s)    Vp(km/s)    Rho(g/cm^3) \\n"
        while(button_data_vp[i]) {
            temp+=button_data_top[i].toPrecision(6) + "    " + button_data_vs[i].toPrecision(4) + "   " + \
                    button_data_vp[i].toPrecision(4) + "   " + button_data_rho[i].toPrecision(4) + "\\n"
            i = i + 1
        }
        const blob = new Blob([temp], { type: 'text/csv;charset=utf-8;' })
        const link = document.createElement('a');
        link.href = URL.createObjectURL(blob);
        link.download = 'vel_model.txt';
        link.target = '_blank'
        link.style.visibility = 'hidden'
        link.dispatchEvent(new MouseEvent('click'))
        
    """)

    simple_text_button = Button(
        label=style_parameter['simple_text_button_label'],
        button_type='default',
        width=style_parameter['button_width'])
    simple_text_button.js_on_click(simple_text_button_callback)
    # ------------------------------
    # "model96" download: fixed header lines, then one row per layer with the
    # thickness top[i+1] - top[i], saved client-side as vel_model96.txt
    model96_button_callback = CustomJS(args=dict(button_data_all_bokeh=button_data_all_bokeh,\
                                                    selected_dot_on_map_bokeh=selected_dot_on_map_bokeh), \
                                           code="""
        var index = selected_dot_on_map_bokeh.data['index']
        var lat = selected_dot_on_map_bokeh.data['lat']
        var lon = selected_dot_on_map_bokeh.data['lon']
        
        var button_data_vs = button_data_all_bokeh.data['button_data_all_vs'][index]
        var button_data_vp = button_data_all_bokeh.data['button_data_all_vp'][index]
        var button_data_rho = button_data_all_bokeh.data['button_data_all_rho'][index]
        var button_data_top = button_data_all_bokeh.data['button_data_all_top'][index]
        
        var csvContent = ""
        var i = 0
        var temp = csvContent
        temp +=  "MODEL." + index + " \\n"
        temp +=  "ShearVelocityModel Lat: "+ lat +"  Lon: " + lon + "\\n"
        temp +=  "ISOTROPIC \\n"
        temp +=  "KGS \\n"
        temp +=  "SPHERICAL EARTH \\n"
        temp +=  "1-D \\n"
        temp +=  "CONSTANT VELOCITY \\n"
        temp +=  "LINE08 \\n"
        temp +=  "LINE09 \\n"
        temp +=  "LINE10 \\n"
        temp +=  "LINE11 \\n"
        temp +=  "      H(KM)   VP(KM/S)   VS(KM/S) RHO(GM/CC)     QP         QS       ETAP       ETAS      FREFP      FREFS \\n"
        while(button_data_vp[i+1]) {
            var thickness = button_data_top[i+1] - button_data_top[i]
            temp+="      " +thickness.toPrecision(6) + "    " + button_data_vp[i].toPrecision(4) + "      " + button_data_vs[i].toPrecision(4) \
                 + "      " + button_data_rho[i].toPrecision(4) + "     0.00       0.00       0.00       0.00       1.00       1.00" + "\\n"
            i = i + 1
        }
        const blob = new Blob([temp], { type: 'text/csv;charset=utf-8;' })
        const link = document.createElement('a');
        link.href = URL.createObjectURL(blob);
        link.download = 'vel_model96.txt';
        link.target = '_blank'
        link.style.visibility = 'hidden'
        link.dispatchEvent(new MouseEvent('click'))
    """)
    model96_button = Button(label=style_parameter['model96_button_label'],
                            button_type='default',
                            width=style_parameter['button_width'])
    model96_button.js_on_click(model96_button_callback)
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'], \
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],\
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    output_file(filename,
                title=style_parameter['html_title'],
                mode=style_parameter['library_source'])
    left_column = Column(depth_slider,
                         map_view,
                         colorbar_fig,
                         annotating_fig01,
                         width=style_parameter['left_column_width'])
    button_pannel = Row(simple_text_button, model96_button)
    right_column = Column(profile_slider,
                          profile_fig,
                          button_pannel,
                          annotating_fig02,
                          width=style_parameter['right_column_width'])
    layout = Row(left_column, right_column)
    save(layout)
Exemple #29
0
    def getHeatMap(results, parameters, bgcolor, contextdict):

        Yname = parameters[2]
        Xname = parameters[3]

        Xdict = BokehHeatmap.getDataParts(Xname, results)
        Ydict = BokehHeatmap.getDataParts(Yname, results)

        #colours row wise
        factorsx = sorted(list(Xdict))
        factorsy = sorted(list(Ydict), reverse=True)

        rowlen = len(factorsx)
        collen = len(factorsy)

        countmatrix, maxval, observations = BokehHeatmap.getCountMatrix(
            Xname, Yname, factorsx, factorsy, results, bokutils.FIRST_YEAR,
            bokutils.LAST_YEAR)
        rate = []
        LOWVALUE = 1
        HIGHVALUE = int(float(maxval) + float(0.5)) - 1
        COLOURBAR = True
        hcolors = bokutils.generateColorGradientHex(
            bokutils.hex_to_rgb(bgcolor), (255, 204, 204), 4)
        hcolors = ["#0C7D03", "#489D42", "#85BE81", "#C2DEC0"]
        textcolors = hcolors[::-1]

        for i, items in enumerate(factorsx):
            factorsx[i] = bokutils.makeLegendKey(items)
        for i, items in enumerate(factorsy):
            factorsy[i] = bokutils.makeLegendKey(items)
        x = []
        y = []
        totcount = 0
        # Initialise with last year as base
        for col in range(collen):
            for row in range(rowlen):
                x.append(factorsx[row])
                y.append(factorsy[col])
                count = int(countmatrix[-1][col][row] + float(0.5))
                rate.append(count)
                totcount = totcount + count

        if (totcount == 0):
            COLOURBAR = False
            HIGHVALUE = float(2)
            LOWVALUE = float(0.8)
        elif (HIGHVALUE == -1):
            COLOURBAR = False
            HIGHVALUE = 1
        elif (HIGHVALUE == 0):
            COLOURBAR = False
            HIGHVALUE = float(0.9)
            LOWVALUE = float(0.8)
        elif (HIGHVALUE == 1):
            HIGHVALUE = float(1.9)
            LOWVALUE = float(0.8)
            textcolors = ["#C2DEC0"]
            hcolors = ["#0C7D03"]
        elif (HIGHVALUE <= LOWVALUE):
            HIGHVALUE = float(HIGHVALUE) - 0.1
            LOWVALUE = float(HIGHVALUE - 0.1)
            COLOURBAR = False

        mapper = LinearColorMapper(palette=textcolors,
                                   high=float(HIGHVALUE),
                                   low=float(LOWVALUE),
                                   low_color="white",
                                   high_color="#45334C")
        source = ColumnDataSource(data=dict(x=x, y=y, rate=rate))
        allsource = ColumnDataSource(data=dict(ratebyyear=countmatrix))

        textmapper = LinearColorMapper(palette=hcolors[:], high=maxval)

        TOOLS = "hover,save"

        if (str(type(Xname)) == "<type 'str'>"):
            DXname = Xname
        else:
            DXname = Xname[0]
        if (str(type(Yname)) == "<type 'str'>"):
            DYname = Yname
        else:
            DYname = Yname[0]

        if (DXname == bokutils.PLOT_GLOCATION_KEY
                or DXname == bokutils.PLOT_GLOCATION_KEY):
            location = parameters[5]
            title = "Visualisations/Plot/X = " + location[
                1:] + " Y = " + "/".join(Yname).replace("_", " ")
        else:
            location = ", location=" + str(
                contextdict[bokutils.PLOT_LOCATION_KEY][0].replace("'", ""))
            title = "Visualisations/Plot/X = " + "/".join(Xname).replace(
                "_", " ") + " Y = " + "/".join(Yname).replace("_", " ")

#Visualisations/Plot/X = Governance/Independent Y = Classification 2018

        p = figure(title=title,
                   x_range=factorsx,
                   y_range=factorsy,
                   x_axis_location="above",
                   plot_width=900,
                   plot_height=900,
                   tools=TOOLS,
                   toolbar_location='below')
        #		   border_fill_color=hcolors[0])

        p.grid.grid_line_color = None
        p.axis.axis_line_color = None
        p.axis.major_tick_line_color = None
        p.axis.major_label_text_font_size = "10pt"
        p.axis.major_label_standoff = 0
        p.xaxis.major_label_orientation = pi / 3

        labels = LabelSet(x='x',
                          y='y',
                          text='rate',
                          level='glyph',
                          text_color={
                              'field': 'rate',
                              'transform': textmapper
                          },
                          x_offset=-10,
                          y_offset=-10,
                          source=source,
                          render_mode='canvas')
        p.rect(x="x",
               y="y",
               width=1,
               height=1,
               source=source,
               fill_color={
                   'field': 'rate',
                   'transform': mapper
               },
               line_color="white")

        if (COLOURBAR):
            color_bar = ColorBar(
                color_mapper=mapper,
                major_label_text_font_size="10pt",
                ticker=BasicTicker(desired_num_ticks=len(hcolors)),
                formatter=PrintfTickFormatter(format="%d"),
                label_standoff=6,
                border_line_color=hcolors[0],
                location=(0, 0))
            p.add_layout(color_bar, 'right')

        p.add_layout(Title(text="X", align="center"), "below")
        p.add_layout(Title(text="Y", align="center"), "left")
        p.add_layout(Title(text="Colour to count legend", align="center"),
                     "right")

        p.add_layout(labels)
        p.select_one(HoverTool).tooltips = [
            ('Point', '@y # @x'),
            ('Count', '@rate '),
        ]

        slider = Slider(start=bokutils.FIRST_YEAR,
                        end=bokutils.LAST_YEAR,
                        value=bokutils.LAST_YEAR,
                        step=1,
                        title='Year',
                        bar_color=hcolors[0])

        paramsource = ColumnDataSource(data=dict(params=[rowlen, collen]))

        callback = CustomJS(args=dict(source=source,
                                      slider=slider,
                                      plot=p,
                                      window=None,
                                      source2=allsource,
                                      source3=paramsource),
                            code="""
        var cb;
        cb = function (source,
                       slider,
                       plot,
                       window,
                       source2,
                       source3)
        {
          var al, arr, data, end, i;
          source = (source === undefined) ? source: source;
          //console.log("slider "+slider);
          slider = (slider === undefined) ? slider: slider;
          plot = (plot === undefined) ? p: plot;
          window = (window === undefined) ? null: window;
          data = source.data;
          
          var arr = source.data["rate"];
          //console.log("rate"+arr);
          var allcounts=source2.data["ratebyyear"];
          //console.log(allcounts);
          var params=source3.data["params"];
          var rowlen=params[0];
          var collen=params[1];
          //console.log("arrlen"+arr.length);
          //console.log("rowlen"+rowlen);
          //console.log("collen"+collen);

          var startidx=slider.value - 1960;
          //console.log("START"+startidx);
          var i=0;
          var j=0;
          var tot=0;

          while (j < collen)
            {
             while (i < rowlen)
              {
                arr[tot] = Math.round(allcounts[startidx][j][i]);
                i = i + 1;
                tot=tot+1;
              }
              j = j + 1;
              i=0;
            }




        //console.log("TOT="+tot);
        //console.log("rate"+arr);
        source.change.emit();

        return null;
        };
        cb(source, slider, plot,window,source2,source3);

        """)

        slider.js_on_change('value', callback)

        wb = widgetbox(children=[slider], sizing_mode='scale_width')
        thisrow = Column(children=[p, wb], sizing_mode='scale_both')

        return thisrow
Exemple #30
0
def _rescale(values, low=0.0, high=1.0):
    """Linearly map *values* so their maximum lands on *high* and minimum on *low*.

    If every value is identical (max == min), the slope would be a division
    by zero (NaN output). In that case each value is mapped to *high*, which
    is what the formula yields for the maximum element.
    """
    vmax = np.max(values)
    vmin = np.min(values)
    span = vmax - vmin
    if span == 0:
        return (values - vmax) + high
    return ((high - low) / span) * (values - vmax) + high


def _weight_callback(source, static_source, column):
    """Build a slider callback that re-weights one score column and the total.

    The slider value ``f`` scales ``static_source.data[column]`` into
    ``source.data[column]``. ``source.data['y']`` is then recomputed as the
    sum of the Reliability, Cost and Performance columns.
    """
    import json

    # The column name is embedded as a JSON string literal, so it is always
    # valid JS regardless of its characters.
    code = """
    var column = %s;
    var data = source.data;
    var static_values = static_source.data[column];
    var f = cb_obj.value;
    var y = data['y'];
    var weighted = data[column];
    var reli = data['Reliability'];
    var perf = data['Performance'];
    var cost = data['Cost'];
    for (var i = 0; i < y.length; i++) {
        weighted[i] = f * static_values[i];
        y[i] = reli[i] + cost[i] + perf[i];
    }
    source.change.emit();
    """ % json.dumps(column)
    return CustomJS(args=dict(source=source, static_source=static_source),
                    code=code)


def generate_performance_plot(hds, hcols):
    """Build a scatter plot of relative hard-drive value with weighting sliders.

    Each drive gets three scores in [0, 1]:

    * cost (higher is cheaper per GB),
    * performance (combined interface, RPM and cache), and
    * reliability (higher means a lower failure rate).

    The plotted y value is the slider-weighted sum of the three scores.

    Parameters
    ----------
    hds : DataFrame-like
        Must provide the columns 'Failure_Rate', 'Capacity', 'Interface',
        'Cache', 'RPM', 'Price_GB', 'Model', 'Number_of_Drives',
        'Percent_of_Drives' and 'Color'.
    hcols : mapping
        Display names (tooltip labels) for 'Capacity', 'Interface', 'RPM',
        'Cache' and 'Price_GB'.

    Returns
    -------
    A Bokeh ``row`` layout with the slider controls and the plot.
    """
    import json
    from bokeh.models import FuncTickFormatter, SingleIntervalTicker

    aux_cols = ['Model', "Number_of_Drives", "Percent_of_Drives", "Color"]
    cols = ['Failure_Rate', 'Capacity', 'Interface', 'Cache', 'RPM', 'Price_GB']
    data = {c: hds[c] for c in cols + aux_cols}

    # Each component is normalised to [0, 1] (max -> 1) before combining.
    performance = _rescale(
        _rescale(hds["Interface"]) + _rescale(hds["RPM"]) + _rescale(hds["Cache"])
    )
    # High failure rate / high price are bad, so invert those scores.
    reliability = 1.0 - _rescale(hds["Failure_Rate"])
    cost = 1.0 - _rescale(hds["Price_GB"])

    slider_start = .5
    data["x"] = np.arange(len(hds['Model']))
    data["y"] = slider_start * cost + slider_start * performance + slider_start * reliability
    data["Cost"] = slider_start * cost
    data["Performance"] = slider_start * performance
    data["Reliability"] = slider_start * reliability

    # Marker size grows with capacity bin. pd.cut yields code -1 for missing
    # capacities; map those to the smallest size instead of sizes[-1].
    sizes = list(range(6, 24, 4))
    groups = pd.cut(hds["Capacity"].values, len(sizes))
    data["Size"] = [sizes[code] if code >= 0 else sizes[0] for code in groups.codes]

    # Unweighted scores; the slider callbacks scale these into `data`.
    static_data = {
        "Cost": cost,
        "Performance": performance,
        "Reliability": reliability,
    }

    _source = ColumnDataSource(data=data)
    _static_source = ColumnDataSource(data=static_data)

    title = "Relative Hard Drive Value"
    plot = figure(title=title, x_axis_location='below', y_axis_location='left',
                  tools=['hover', 'save'])
    hover = plot.select_one(HoverTool)
    hover.tooltips = [
        ("Model ", "@Model"),
        ("Failure Rate ", "@Failure_Rate"),
        (hcols["Capacity"], "@Capacity"),
        (hcols["Interface"], "@Interface"),
        (hcols["RPM"], "@RPM"),
        (hcols["Cache"], "@Cache"),
        (hcols["Price_GB"], "@Price_GB{1.11}"),
    ]

    plot.circle('x', 'y', source=_source, size='Size', color="Color")

    # The y total is at most 3 (three scores, each weighted by at most 1.0).
    plot.y_range = Range1d(-.1, 3.1)

    plot.xaxis.visible = False
    plot.yaxis.visible = False
    ticker = SingleIntervalTicker(interval=1, num_minor_ticks=0)
    xaxis = LinearAxis(axis_label="Model", ticker=ticker)

    # Tick i shows the model name of row i. JSON is valid JS, unlike a Python
    # repr. Ticks outside the data range get an empty label.
    labels = [str(s) for s in hds["Model"]]
    xaxis.formatter = FuncTickFormatter(code="""
    var labels = %s;
    var label = labels[tick];
    return (label === undefined) ? "" : label;
    """ % json.dumps(labels))

    xaxis.major_label_orientation = -np.pi / 2.7
    plot.add_layout(xaxis, 'below')

    callback1 = _weight_callback(_source, _static_source, 'Cost')
    callback2 = _weight_callback(_source, _static_source, 'Reliability')
    callback3 = _weight_callback(_source, _static_source, 'Performance')

    plot.min_border_left = 0
    plot.xaxis.axis_line_width = 2
    plot.yaxis.axis_line_width = 2
    plot.title.text_font_size = '16pt'
    plot.xaxis.axis_label_text_font_size = "14pt"
    plot.xaxis.major_label_text_font_size = "14pt"
    plot.yaxis.axis_label_text_font_size = "14pt"
    plot.yaxis.major_label_text_font_size = "14pt"
    plot.ygrid.grid_line_color = None
    plot.xgrid.grid_line_color = None
    plot.toolbar.logo = None
    plot.outline_line_width = 0
    plot.outline_line_color = "white"

    slider1 = Slider(start=0.0, end=1.0, value=slider_start, step=.05, title="Price")
    slider2 = Slider(start=0.0, end=1.0, value=slider_start, step=.05, title="Reliability")
    slider3 = Slider(start=0.0, end=1.0, value=slider_start, step=.05, title="Performance")
    slider1.js_on_change('value', callback1)
    slider2.js_on_change('value', callback2)
    slider3.js_on_change('value', callback3)
    controls = widgetbox([slider1, slider2, slider3], width=200)
    layout = row(controls, plot)
    return layout