# --- Build the scores / expected-wins views and wire up the widgets. ---
# This fragment only uses names created earlier (team_objs, sc_sources,
# line_colors, league_obj, weeks, owner_to_idx, owners, week_num, num_teams,
# the widgets and the *_handler callbacks); statement order matters.

# One placeholder (empty-string) legend label per team.
legend_labels = ['' for _ in range(num_teams)]
# Draw the score lines; the returned renderers are kept in sc_renderers.
sc_renderers = plot_sc_data(team_objs, sc_sources, line_colors)
# Called only for its side effects (the return value is discarded).
# NOTE(review): presumably it fills in expected-win data on team_objs that
# get_ew_sources / initialize_ew_table read below, so it must run first.
# TODO confirm.
compile_expected_wins(league_obj, team_objs, weeks, owner_to_idx, num_teams)
ew_sources = get_ew_sources(weeks, team_objs, owners, week_num, num_teams)
expected_wins_table = initialize_ew_table(team_objs, week_num, num_teams)
# Wrap the table in a column layout; this wrapper is the "Summary" tab below.
table_wrap = column(children=[expected_wins_table])
ew_renderers = plot_ew_data(team_objs, ew_sources, line_colors)

# register callback handlers to respond to changes in widget values
# on_change / on_click register Python handlers. js_on_change attaches a
# browser-side callback (ga_view_callback, presumably a CustomJS, e.g. for
# analytics; confirm where it is defined).
lg_id_input.on_change('value', league_id_handler)
lg_id_input.js_on_change('value', ga_view_callback)
week_slider.on_change('value', week_slider_handler)
team1_dd.on_change('value', team1_select_handler)
team2_dd.on_change('value', team2_select_handler)
comp_button.on_click(helper_handler)
year_input.on_change('value', season_handler)

# arrange layout
# Three tabs: Scores (plot1_wrap), Expected Wins (plot2_wrap) and
# Summary (table_wrap).
tab1 = Panel(child=plot1_wrap, title='Scores')
tab2 = Panel(child=plot2_wrap, title='Expected Wins')
tab3 = Panel(child=table_wrap, title='Summary')
figures = Tabs(tabs=[tab1, tab2, tab3], width=500)
# Vertical stack of the two team dropdowns and the compare button.
compare_widgets = column(team1_dd, team2_dd, comp_button)
def _read_js_function(filename):
    """Return the text of ``JS_Functions/<filename>``, resolved against the CWD.

    The file is read inside a ``with`` block so the handle is closed promptly.
    The previous ``open(...).read()`` pattern never closed it.
    """
    js_path = os.path.join(os.getcwd(), 'JS_Functions', filename)
    with open(js_path, encoding='utf-8') as js_file:
        return js_file.read()


# --- Mass Finder panel: predicts m/z lines for a complex of given mass. ---
mass_finder_header = Div(text= " <h2>Mass Finder</h2>", height=45, width=400 )
# mass_finder_range_text = Div(text= " Range mz:", width= 150, height=30 )
mass_finder_range_slider = RangeSlider(start=1.0, end=500.0, value=(1.0, 50.0), title='Charge range:',
                                       name='mass_finder_range_slider', step=1, width=250, height=30)
# mass_finder_mass_text = Div(text= " Mass of Complex (kDa):", width= 150, height=30 )
# NOTE(review): name='gau_sigma' looks copy-pasted from the Gaussian smoothing
# control. It is kept unchanged because other code may look this model up by
# name, but confirm it does not collide with the real 'gau_sigma' slider.
mass_finder_mass = Slider(value=100, start=0.0, end=1000.0, step=10.0, title='Mass of Complex (kDa)',
                          name='gau_sigma', width=250, height=30)
mass_finder_exact_mass_text = Div(text= "Enter exact Mass (Da)", width= 150, height=30 )
# Slider is in kDa and the text box is in Da, hence the * 1000.
mass_finder_exact_mass_sele = TextInput(value=str(mass_finder_mass.value * 1000), disabled=False,
                                        width=100, height=30)
mass_finder_line_text = Div(text= "Show mz prediction", width= 150, height=30 )
mass_finder_line_sele = Toggle(label='off', active=False, width=100, height=30, callback=toggle_cb)

# Browser-side callbacks; their JS source lives in JS_Functions/*.js.
mass_finder_cb = CustomJS(
    args=dict(mass_finder_line_sele=mass_finder_line_sele, raw_mz=raw_mz,
              mass_finder_data=mass_finder_data,
              mass_finder_exact_mass_sele=mass_finder_exact_mass_sele,
              mass_finder_mass=mass_finder_mass,
              mass_finder_range_slider=mass_finder_range_slider, mfl=mfl),
    code=_read_js_function("mass_finder_cb.js"))
mass_finder_exact_cb = CustomJS(
    args=dict(mass_finder_line_sele=mass_finder_line_sele,
              mass_finder_exact_mass_sele=mass_finder_exact_mass_sele,
              mass_finder_mass=mass_finder_mass),
    code=_read_js_function("mass_finder_exact_cb.js"))
mass_finder_exact_mass_sele.js_on_change('value', mass_finder_exact_cb)

mass_finder_column = Column(mass_finder_header, mass_finder_mass, mass_finder_range_slider,
                            Row(mass_finder_exact_mass_text, mass_finder_exact_mass_sele),
                            Row(mass_finder_line_text, mass_finder_line_sele), visible=False)
# The `mass_finder` toggle shows/hides the whole panel; the line toggle
# shows/hides the predicted-m/z glyph `mfl`.
mass_finder.js_link('active', mass_finder_column, 'visible')
mass_finder_line_sele.js_link('active', mfl, 'visible')

# Recompute the prediction whenever mass, visibility or charge range change.
mass_finder_mass.js_on_change('value', mass_finder_cb)
mass_finder_line_sele.js_on_change('active', mass_finder_cb)
mass_finder_range_slider.js_on_change('value', mass_finder_cb)

### DATA PROCESSING ###
cropping = Div(text= " Range mz:", width= 150, height=30 )
# crop_max = Div(text= " ", width= 150, height=30 )
gau_name = Div(text= " Gaussian Smoothing:", width= 150, height=30 )
n_smooth_name = Div(text= " Repeats of Smoothing:", width= 150, height=30 )
# bin_name = Div(text= " Bin Every:", width= 150, height=30 )
int_name = Div(text= " Intensity Threshold (%)", width= 150, height=30 )
def spectroscopy_plot(obj_id, spec_id=None):
    """Build the interactive spectroscopy plot for one object.

    Each spectrum attached to ``Obj`` ``obj_id`` (or only the spectrum whose
    id equals ``spec_id``, when given) is normalized and drawn as a line. The
    layout also contains:

    * a per-spectrum visibility toggle,
    * three checkbox columns that overlay the reference lines in
      ``SPEC_LINES``,
    * redshift and expansion-velocity inputs that shift those reference lines.

    Parameters
    ----------
    obj_id
        Identifier passed to ``Obj.query.get``.
    spec_id
        Optional spectrum id (anything ``int()`` accepts). When given, only
        that spectrum is plotted.

    Returns
    -------
    ``_plot_to_json(layout)`` for the assembled layout, or the tuple
    ``(None, None, None)`` when there are no matching spectra.
    """
    obj = Obj.query.get(obj_id)
    # Reuse the object already loaded instead of querying it a second time.
    spectra = obj.spectra
    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if len(spectra) == 0:
        return None, None, None

    color_map = dict(zip([s.id for s in spectra], viridis(len(spectra))))

    data = []
    for s in spectra:
        # normalize spectra to a common average flux per resolving
        # element of 1 (facilitates easy visual comparison)
        normfac = np.sum(np.gradient(s.wavelengths) * s.fluxes) / len(s.fluxes)
        if not (np.isfinite(normfac) and normfac > 0):
            # otherwise normalize the value at the median wavelength to 1
            median_wave_index = np.argmin(
                np.abs(s.wavelengths - np.median(s.wavelengths)))
            normfac = s.fluxes[median_wave_index]

        df = pd.DataFrame({
            'wavelength': s.wavelengths,
            'flux': s.fluxes / normfac,
            'id': s.id,
            'telescope': s.instrument.telescope.name,
            'instrument': s.instrument.name,
            'date_observed': s.observed_at.date().isoformat(),
            'pi': s.assignment.run.pi if s.assignment is not None else "",
        })
        data.append(df)
    data = pd.concat(data)

    # groupby sorts by id, so *this* order (not the order of ``spectra``) is
    # the order of the "s<i>" renderers. Toggle labels and colors must follow
    # it, or label i would control some other spectrum's line.
    split = data.groupby('id')
    spectra_by_id = {s.id: s for s in spectra}
    ordered_ids = [key for key, _ in split]

    hover = HoverTool(tooltips=[
        ('wavelength', '$x'),
        ('flux', '$y'),
        ('telescope', '@telescope'),
        ('instrument', '@instrument'),
        ('UTC date observed', '@date_observed'),
        ('PI', '@pi'),
    ])

    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    # TODO how to choose a good default?
    # (A smoothed-data y-range used to be computed here, but it was always
    # replaced by this range straight after the figure was created.)
    plot = figure(
        plot_width=600,
        plot_height=300,
        y_range=Range1d(0, 1.03 * data.flux.max()),
        x_range=(xmin, xmax),
        sizing_mode='scale_both',
        tools='box_zoom,wheel_zoom,pan,reset',
        active_drag='box_zoom',
    )
    plot.add_tools(hover)

    # JS callbacks reference renderers by these names ("s0", "el3", ...).
    model_dict = {}
    for i, (key, df) in enumerate(split):
        model_dict['s' + str(i)] = plot.line(x='wavelength', y='flux',
                                             color=color_map[key],
                                             source=ColumnDataSource(df))
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None

    spec_labels = []
    for key in ordered_ids:
        s = spectra_by_id[key]
        spec_labels.append(
            f'{s.instrument.telescope.nickname}/{s.instrument.name} '
            f'({s.observed_at.date().isoformat()})')

    toggle = CheckboxWithLegendGroup(
        labels=spec_labels,
        active=list(range(len(spectra))),
        colors=[color_map[k] for k in ordered_ids],
    )
    toggle.callback = CustomJS(
        args={
            'toggle': toggle,
            **model_dict
        },
        code="""
        for (let i = 0; i < toggle.labels.length; i++) {
            eval("s" + i).visible = (toggle.active.includes(i))
        }
        """,
    )

    obj_redshift = obj.redshift if obj.redshift is not None else 0.0

    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=obj_redshift,
        start=0.0,
        end=1.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(obj_redshift))
    z_slider.callback = CustomJS(
        args={
            'slider': z_slider,
            'textinput': z_textinput
        },
        code="""
            textinput.value = parseFloat(slider.value).toFixed(3);
            textinput.change.emit();
        """,
    )
    z = column(z_title, z_slider, z_textinput)

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.callback = CustomJS(
        args={
            'slider': v_exp_slider,
            'textinput': v_exp_textinput
        },
        code="""
            textinput.value = parseFloat(slider.value).toFixed(0);
            textinput.change.emit();
        """,
    )
    v_exp = column(v_title, v_exp_slider, v_exp_textinput)

    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + obj_redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            # Fluxes are normalized to about 1, so the old y1=1e-13 made the
            # lines invisible. Use a tall segment; the y_range clips it.
            y0=0,
            y1=1e4,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral lines into 3 columns
    element_dicts = np.array_split(
        np.array(list(SPEC_LINES.items()), dtype=object), 3)

    elements_groups = []
    col_offset = 0
    for element_dict in element_dicts:
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(
            labels=labels,
            active=[],
            colors=colors,
        )
        elements_groups.append(elements)

        # TODO callback policy: don't require submit for text changes?
        elements.callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {col_offset} + elements.labels.length;
            let local_i = 0;
            for (let i = {col_offset}; i < i_max; i++) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                  (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )

        col_offset += len(labels)

    # Our current version of Bokeh doesn't properly execute multiple callbacks
    # https://github.com/bokeh/bokeh/issues/6508
    # Workaround is to manually put the code snippets together
    # Plain (non-f) string, so braces are single.
    update_lines_js = """
        // Update plot data for each element
        let c = 299792.458; // speed of light in km / s
        const offset_col_1 = elements0.labels.length;
        const offset_col_2 = offset_col_1 + elements1.labels.length;
        const i_max = offset_col_2 + elements2.labels.length;
        for (let i = 0; i < i_max; i++) {
            let el = eval("el" + i);
            el.visible =
                elements0.active.includes(i) ||
                elements1.active.includes(i - offset_col_1) ||
                elements2.active.includes(i - offset_col_2);
            el.data_source.data.x = el.data_source.data.wavelength.map(
                x_i => (x_i * (1 + parseFloat(z.value)) /
                              (1 + parseFloat(v_exp.value) / c))
            );
            el.data_source.change.emit();
        }
    """
    line_args = {
        'elements0': elements_groups[0],
        'elements1': elements_groups[1],
        'elements2': elements_groups[2],
        'z': z_textinput,
        'v_exp': v_exp_textinput,
        **model_dict,
    }

    # Sync the slider from the text box. Assign a *number*: the old code
    # assigned the string from toFixed(), so the slider callback then called
    # toFixed on a string and threw.
    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={**line_args, 'slider': z_slider},
            code="""
            // Update slider value to match text input
            const z_num = parseFloat(z.value);
            if (!isNaN(z_num)) {
                slider.value = Number(z_num.toFixed(3));
            }
            """ + update_lines_js,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={**line_args, 'slider': v_exp_slider},
            code="""
            // Update slider value to match text input
            const v_num = parseFloat(v_exp.value);
            if (!isNaN(v_num)) {
                slider.value = Number(v_num.toFixed(3));
            }
            """ + update_lines_js,
        ),
    )

    row1 = row(plot, toggle)
    row2 = row(elements_groups)
    row3 = row(z, v_exp)
    layout = column(row1, row2, row3)
    return _plot_to_json(layout)
def compare():
    """Render the pathway-comparison page for two of the user's projects.

    Query args:
        proj_a, proj_b: file ids of the two projects to compare. If either is
            missing or non-integer, or if the pair is not among the current
            user's completed uploads, the two completed uploads with the
            highest ``file_id`` are used.
        include: passed through to the template as ``include_genes``.

    Returns:
        The rendered ``pway/compare.html`` template. If the user has fewer
        than two completed projects, a warning is flashed and the response
        redirects to ``.index``.
    """
    # if proj among arguments, show this tree first.
    try:
        proj_a = int(request.args.get('proj_a', None))
        proj_b = int(request.args.get('proj_b', None))
    except (TypeError, ValueError):
        proj_a = proj_b = None
    include = request.args.get('include', None)

    # list of projects (and proj_names) used to create dropdown project selector
    upload_list = UserFile.query.filter_by(user_id=current_user.id).\
        filter_by(run_complete=True).order_by(UserFile.file_id).all()

    if len(upload_list) <= 1:
        # not enough projects yet!
        flash("Two completed projects are required for a comparison.", "warning")
        return redirect(url_for('.index'))

    # Default: the two highest file_ids. Override only if BOTH requested
    # projects are among the user's finished projects.
    current_proj_a = upload_list[-2]
    current_proj_b = upload_list[-1]
    if proj_a and proj_b:
        current_temp_a = [u for u in upload_list if u.file_id == proj_a]
        current_temp_b = [u for u in upload_list if u.file_id == proj_b]
        if len(current_temp_a) == 1 and len(current_temp_b) == 1:
            current_proj_a = current_temp_a[0]
            current_proj_b = current_temp_b[0]

    detail_path1 = naming_rules.get_detailed_path(current_proj_a)
    detail_path2 = naming_rules.get_detailed_path(current_proj_b)
    js_name1 = naming_rules.get_js_name(current_proj_a)
    js_name2 = naming_rules.get_js_name(current_proj_b)
    xlabel = u"Effect size ({})".format(current_proj_a.get_fancy_filename())
    ylabel = u"Effect size ({})".format(current_proj_b.get_fancy_filename())

    # load pathways with 1+ mutation in 1+ patients,
    # ignoring ones with 'cancer' etc in name
    all_paths1 = load_pathway_list_from_file(detail_path1)
    all_paths2 = load_pathway_list_from_file(detail_path2)

    # Pathways with a truthy gene_set, keyed by path_id in file order.
    # NOTE(review): the original comment described these as "IDs with p<0.05
    # and +ve effect", but the only filter applied here is `i.gene_set`.
    # Confirm whether the loader already applies the p/effect cut.
    sig_p = OrderedDict(
        [(i.path_id, i.nice_name) for i in all_paths1 if i.gene_set])
    sig_pids1 = list(sig_p)
    sig_p2 = OrderedDict(
        [(i.path_id, i.nice_name) for i in all_paths2 if i.gene_set])
    sig_pids2 = list(sig_p2)
    # Union of both, proj1 pathways first (original note: ordered by proj1
    # effect size).
    sig_p.update(sig_p2)

    # BUILD DATAFRAME WITH ALL sig PATHWAYS, proj1 object order.
    # order important (also interpolated into the JS filter callback below)
    columns = ['path_id', 'pname', 'ind1', 'ind2', 'e1', 'e2', 'e1_only',
               'e2_only', 'q1', 'q2']
    df = pd.DataFrame(index=list(sig_p.keys()),
                      data={'pname': list(sig_p.values())},
                      columns=columns)
    for path_group, evar, qvar, ind, sigs in \
            [(all_paths1, 'e1', 'q1', 'ind1', sig_pids1),
             (all_paths2, 'e2', 'q2', 'ind2', sig_pids2)]:
        # path_id -> position in this project's list (was an O(n) list.index
        # per pathway).
        sig_pos = {pid: pos for pos, pid in enumerate(sigs)}
        for path in path_group:
            path_id = path.path_id
            if path_id not in sig_p:
                continue
            df.loc[path_id, evar] = get_effect(path)
            df.loc[path_id, qvar] = get_q(path)
            df.loc[path_id, ind] = sig_pos.get(path_id, -1)

    # Assign back explicitly. Chained `df.ind1.fillna(..., inplace=True)` is
    # deprecated and does not update df under pandas Copy-on-Write.
    df['ind1'] = df['ind1'].fillna(-1)
    df['ind2'] = df['ind2'].fillna(-1)
    # Effects present in only one project (used by the marginal plots).
    df['e1_only'] = df['e1'].where(df['e2'].isnull())
    df['e2_only'] = df['e2'].where(df['e1'].isnull())

    inds1 = list(df.ind1)
    inds2 = list(df.ind2)
    source = ColumnDataSource(data=df)
    source_full = ColumnDataSource(data=df)
    source.name, source_full.name = 'data_visible', 'data_full'

    # SET UP FIGURE
    # Lower bounds are pushed 20% further down (x0.8 if positive, x1.2 if
    # negative); upper bounds are raised by 20%.
    minx = df.e1.min()
    minx *= 1 - minx / abs(minx) * 0.2
    miny = df.e2.min()
    miny *= 1 - miny / abs(miny) * 0.2
    maxx = df.e1.max() * 1.2
    maxy = df.e2.max() * 1.2
    TOOLS = "lasso_select,box_select,hover,crosshair,pan,wheel_zoom,"\
            "box_zoom,reset,tap,help"
    # poly_select,lasso_select, previewsave

    # SUPLOTS: p = main scatter, pa = x-only margin, pb = y-only margin
    p = figure(plot_width=DIM_COMP_W, plot_height=DIM_COMP_H, tools=TOOLS,
               title=None, logo=None, toolbar_location="above",
               x_range=Range1d(minx, maxx), y_range=Range1d(miny, maxy),
               x_axis_type="log", y_axis_type="log"
               )
    pb = figure(plot_width=DIM_COMP_SM, plot_height=DIM_COMP_H, tools=TOOLS,
                y_range=p.y_range, x_axis_type="log", y_axis_type="log")
    pa = figure(plot_width=DIM_COMP_W, plot_height=DIM_COMP_SM, tools=TOOLS,
                x_range=p.x_range, x_axis_type="log", y_axis_type="log")

    # SPANS
    p.add_layout(plot_fns.get_span(1, 'height'))
    p.add_layout(plot_fns.get_span(1, 'width'))
    pa.add_layout(plot_fns.get_span(1, 'height'))
    pb.add_layout(plot_fns.get_span(1, 'width'))

    # STYLE
    for ax in [p, pa, pb]:
        ax.grid.visible = False
        ax.outline_line_width = 2
        ax.background_fill_color = 'whitesmoke'
    for ax in [pa, pb]:
        ax.xaxis.visible = False
        ax.yaxis.visible = False

    pa.title.text = xlabel
    pb.title.text = ylabel
    pa.title_location, pa.title.align = 'below', 'center'
    pb.title_location, pb.title.align = 'left', 'center'

    # WIDGETS
    q_input = TextInput(value='', title="P* cutoff",
                        placeholder='e.g. 0.05')
    gene_input = TextInput(value='', title="Gene list",
                           placeholder='e.g. TP53,BRAF')
    radio_include = RadioGroup(labels=["Include", "Exclude"], active=0)
    widgets = widgetbox(q_input, gene_input, radio_include, width=200,
                        css_classes=['widgets_sg'])

    grid = gridplot([[pb, p, widgets],
                     [Spacer(width=DIM_COMP_SM), pa, Spacer()]],
                    sizing_mode='fixed')

    # Gene include/exclude selection; selectPathwaysByGenes is a page-level
    # JS function (not defined here).
    cb_inclusion = CustomJS(args=dict(genes=gene_input), code="""
        var gene_str = genes.value
        if (!gene_str)
            return;
        var include = cb_obj.active == 0 ? true : false
        selectPathwaysByGenes(gene_str, include);
        """)
    cb_genes = CustomJS(args=dict(radio=radio_include), code="""
        var gene_str = cb_obj.value
        if (!gene_str)
            return;
        var include = radio.active == 0 ? true : false
        selectPathwaysByGenes(gene_str, include);
        """)
    radio_include.js_on_change('active', cb_inclusion)
    gene_input.js_on_change('value', cb_genes)

    # SCATTER
    p.circle("e1", "e2", source=source, **SCATTER_KW)
    pa.circle('e1_only', 1, source=source, **SCATTER_KW)
    pb.circle(1, 'e2_only', source=source, **SCATTER_KW)

    # HOVER
    for hover in grid.select(dict(type=HoverTool)):
        hover.tooltips = OrderedDict([
            ("name", "@pname"),
            ("effects", "(@e1, @e2)"),
            ("P*", ("(@q1, @q2)"))
        ])

    # ADD Q FILTERING CALLBACK
    # Rebuilds `source` from `full`, keeping rows with q1 or q2 <= cutoff.
    # `scatter_array`, `$` and updateIfSelectionChange_afterWait are
    # page-level JS globals.
    callback = CustomJS(args=dict(source=source, full=source_full), code="""
        // get old selection indices, if any
        var prv_selected = source.selected['1d'].indices;
        var prv_select_full = []
        for(var i=0; i<prv_selected.length; i++){
            prv_select_full.push(scatter_array[prv_selected[i]])
        }
        var new_selected = []
        var q_val = cb_obj.value;
        if(q_val == '')
            q_val = 1
        var fullset = full.data;
        var n_total = fullset['e1'].length;
        // Convert float64arrays to array
        var col_names = %s ;
        col_names.forEach(function(col_name){
            source.data[col_name] = [].slice.call(source.data[col_name])
            source.data[col_name].length = 0
        })
        scatter_array.length = 0;
        var j = -1;  // new glyph indices
        for (i = 0; i < n_total; i++) {
            this_q1 = fullset['q1'][i];
            this_q2 = fullset['q2'][i];
            if(this_q1 <= q_val || this_q2 <= q_val){
                j++; // preserve previous selection if still visible
                col_names.forEach(function(col){
                    source.data[col].push(fullset[col][i]);
                })
                scatter_array.push(i)
                if($.inArray(i, prv_select_full) > -1){
                    new_selected.push(j);
                }
            }
        }
        source.selected['1d'].indices = new_selected;
        source.trigger('change');
        updateIfSelectionChange_afterWait();
        """ % columns)

    q_input.js_on_change('value', callback)

    script, div = plot_fns.get_bokeh_components(grid)

    proj_dir_a = naming_rules.get_project_folder(current_proj_a)
    proj_dir_b = naming_rules.get_project_folder(current_proj_b)
    has_cnv = (os.path.exists(os.path.join(proj_dir_a, 'matrix_svg_cnv')) and
               os.path.exists(os.path.join(proj_dir_b, 'matrix_svg_cnv')))

    return render_template('pway/compare.html',
                           current_projs=[current_proj_a, current_proj_b],
                           inds_use=[inds1, inds2],
                           has_cnv=has_cnv,
                           js_name_a=js_name1,
                           js_name_b=js_name2,
                           projects=upload_list,
                           bokeh_script=script,
                           bokeh_div=div,
                           include_genes=include,
                           resources=plot_fns.resources)
def volcano(data, folder='', tohighlight=None, tooltips=None,
            title="volcano plot", xlabel='log-fold change', ylabel='-log(Q)',
            maxvalue=100, searchbox=False, logfoldtohighlight=0.15,
            pvaltohighlight=0.1, showlabels=False):
    """
    Make an interactive volcano plot from Differential Expression analysis tools outputs

    Args:
    -----
        data: a df with rows genes and cols [log2FoldChange, pvalue, gene_id]
        folder: str of location where to save the plot, won't save if empty.
            It is prefixed directly to the file name, so include a trailing
            path separator.
        tohighlight: list[str] of genes to highlight in the plot
        tooltips: list[tuples(str,str)] if user wants to specify another bokeh tooltip
            (default: [('gene', '@gene_id')])
        title: str plot title
        xlabel: str if user wants to specify the title of the x axis
        ylabel: str if user wants to specify the title of the y axis
        maxvalue: float the max -log2(pvalue) authorized, useful when managing inf vals
        searchbox: bool whether or not to add a searchBox to interactively highlight genes
        logfoldtohighlight: float min logfoldchange when to display points
        pvaltohighlight: float min pvalue when to display points
        showlabels: bool whether or not to show a text above each datapoint with its label information

    Returns:
    --------
        The bokeh object (the figure, or a column of [search box, figure]
        when ``searchbox`` is True)
    """
    # Avoid a shared mutable default list.
    if tooltips is None:
        tooltips = [('gene', '@gene_id')]

    to_plot_not, to_plot_yes = selector(
        data, tohighlight if tohighlight is not None else [],
        logfoldtohighlight, pvaltohighlight)
    hover = bokeh.models.HoverTool(tooltips=tooltips, names=['circles'])

    # Create figure
    p = bokeh.plotting.figure(title=title, plot_width=650, plot_height=450)

    p.xgrid.grid_line_color = 'white'
    p.ygrid.grid_line_color = 'white'
    p.xaxis.axis_label = xlabel
    p.yaxis.axis_label = ylabel
    # Set on the figure itself. When searchbox=True, `p` is later rebound to
    # a Column layout, which has no `output_backend` property; setting it
    # there raised.
    p.output_backend = "svg"

    # Add the hover tool
    p.add_tools(hover)
    p, source1 = add_points(p, to_plot_not, 'log2FoldChange', 'pvalue',
                            color='#1a9641', maxvalue=maxvalue)
    p, source2 = add_points(p, to_plot_yes, 'log2FoldChange', 'pvalue',
                            color='#fc8d59', alpha=0.6, outline=True,
                            maxvalue=maxvalue)
    if showlabels:
        labels = LabelSet(x='log2FoldChange', y='transformed_q',
                          text_font_size='7pt', text="gene_id", level="glyph",
                          x_offset=5, y_offset=5, source=source2,
                          render_mode='canvas')
        p.add_layout(labels)
    if searchbox:
        text = TextInput(title="text", value="gene")
        # Highlight every non-highlighted point whose gene_id equals the typed text.
        text.js_on_change('value', CustomJS(
            args=dict(source=source1), code="""
            var data = source.data;
            var value = cb_obj.value;
            var gene_id = data.gene_id;
            for (var i = 0; i < gene_id.length; i++) {
                if (gene_id[i] === value) {
                    data.size[i] = 7;
                    data.alpha[i] = 1;
                    data.color[i] = '#fc8d59';
                }
            }
            source.change.emit();
            """))
        p = column(text, p)
    if folder:
        save(p, folder + title.replace(' ', "_") + "_volcano.html")
        export_svg(p, filename=folder + title.replace(' ', "_") + "_volcano.svg")
    try:
        show(p)
    except Exception:
        # The original code retries show() once. NOTE(review): presumably a
        # workaround for a flaky first call; a second failure propagates.
        show(p)
    return p
def spectroscopy_plot(obj_id, user, spec_id=None, width=600, height=300):
    """Build the interactive spectroscopy plot for one object, as Bokeh JSON.

    Only spectra of ``obj_id`` that are shared with one of
    ``user.accessible_groups`` are plotted, optionally restricted to the
    spectrum whose id equals ``spec_id``. Each spectrum is normalized to a
    median flux of 1 and drawn as a step line. The layout also contains:

    * a visibility toggle per spectrum,
    * checkbox columns for the reference lines in ``SPEC_LINES``,
    * redshift and expansion-velocity controls that shift those lines.

    Parameters
    ----------
    obj_id
        Object identifier.
    user
        Object exposing ``accessible_groups`` (items with ``.id``).
    spec_id
        Optional spectrum id (anything ``int()`` accepts).
    width
        Layout width in pixels; also used to size the legend columns.
    height
        Accepted for backward compatibility. Currently unused: the plot
        height comes from ``aspect_ratio=2``.

    Returns
    -------
    ``bokeh_embed.json_item(layout)``, or ``(None, None, None)`` when no
    accessible spectra match.
    """
    obj = Obj.query.get(obj_id)
    spectra = (
        DBSession().query(Spectrum).join(Obj).join(GroupSpectrum).filter(
            Spectrum.obj_id == obj_id,
            GroupSpectrum.group_id.in_([g.id for g in user.accessible_groups]),
        )
    ).all()

    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if len(spectra) == 0:
        return None, None, None

    rainbow = cm.get_cmap('rainbow', len(spectra))
    palette = list(map(rgb2hex, rainbow(range(len(spectra)))))
    color_map = dict(zip([s.id for s in spectra], palette))

    data = []
    for s in spectra:
        # normalize spectra to a median flux of 1 for easy comparison
        normfac = np.nanmedian(s.fluxes)
        if s.assignment is not None:
            pi = s.assignment.run.pi
        elif s.followup_request is not None:
            pi = s.followup_request.allocation.pi
        else:
            pi = ""

        df = pd.DataFrame({
            'wavelength': s.wavelengths,
            'flux': s.fluxes / normfac,
            'id': s.id,
            'telescope': s.instrument.telescope.name,
            'instrument': s.instrument.name,
            'date_observed': s.observed_at.date().isoformat(),
            'pi': pi,
        })
        data.append(df)
    data = pd.concat(data)

    # groupby sorts by id; this fixes the order of the "s<i>" renderers, so
    # labels and colors below follow the same order.
    split = data.groupby('id')
    # Reuse the already-loaded (access-filtered) spectra instead of issuing
    # one Spectrum.query.get per spectrum.
    spectra_by_id = {s.id: s for s in spectra}
    ordered_ids = [key for key, _ in split]

    hover = HoverTool(tooltips=[
        ('wavelength', '$x'),
        ('flux', '$y'),
        ('telescope', '@telescope'),
        ('instrument', '@instrument'),
        ('UTC date observed', '@date_observed'),
        ('PI', '@pi'),
    ])

    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    # TODO how to choose a good default?
    # (A smoothed-data y-range used to be computed here, but it was always
    # replaced by this range right after the figure was created.)
    plot = figure(
        aspect_ratio=2,
        sizing_mode='scale_width',
        y_range=Range1d(0, 1.03 * data.flux.max()),
        x_range=(xmin, xmax),
        tools='box_zoom,wheel_zoom,pan,reset',
        active_drag='box_zoom',
    )
    plot.add_tools(hover)

    # JS callbacks reference renderers by these names ("s0", "el3", ...).
    model_dict = {}
    for i, (key, df) in enumerate(split):
        model_dict['s' + str(i)] = plot.step(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
            mode="center",
        )
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None

    spec_labels = []
    for key in ordered_ids:
        s = spectra_by_id[key]
        spec_labels.append(
            f'{s.instrument.telescope.nickname}/{s.instrument.name} '
            f'({s.observed_at.date().isoformat()})')

    toggle = CheckboxWithLegendGroup(
        labels=spec_labels,
        active=list(range(len(spectra))),
        colors=[color_map[k] for k in ordered_ids],
        width=width // 5,
    )
    toggle.js_on_click(
        CustomJS(
            args={
                'toggle': toggle,
                **model_dict
            },
            code="""
          for (let i = 0; i < toggle.labels.length; i++) {
              eval("s" + i).visible = (toggle.active.includes(i))
          }
    """,
        ),
    )

    obj_redshift = obj.redshift if obj.redshift is not None else 0.0

    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=obj_redshift,
        start=0.0,
        end=1.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(obj_redshift))
    z_slider.js_on_change(
        'value',
        CustomJS(
            args={
                'slider': z_slider,
                'textinput': z_textinput
            },
            code="""
            textinput.value = parseFloat(slider.value).toFixed(3);
            textinput.change.emit();
        """,
        ),
    )
    z = column(z_title, z_slider, z_textinput)

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.js_on_change(
        'value',
        CustomJS(
            args={
                'slider': v_exp_slider,
                'textinput': v_exp_textinput
            },
            code="""
            textinput.value = parseFloat(slider.value).toFixed(0);
            textinput.change.emit();
        """,
        ),
    )
    v_exp = column(v_title, v_exp_slider, v_exp_textinput)

    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + obj_redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            y0=0,
            y1=1e4,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split the spectral-line legend into columns: line k goes in column
    # k % n_columns (the JS below walks each column with that stride).
    n_columns = 7
    element_dicts = zip(*itertools.zip_longest(*[iter(SPEC_LINES.items())] * n_columns))

    elements_groups = []  # The Bokeh checkbox groups
    callbacks = []  # The checkbox callbacks for each element
    for column_idx, element_dict in enumerate(element_dicts):
        element_dict = [e for e in element_dict if e is not None]
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(labels=labels,
                                           active=[],
                                           colors=colors,
                                           width=width // (n_columns + 1))
        elements_groups.append(elements)
        callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {column_idx} + {n_columns} * elements.labels.length;
            let local_i = 0;
            for (let i = {column_idx}; i < i_max; i = i + {n_columns}) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                  (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )
        elements.js_on_click(callback)
        callbacks.append(callback)

    # Keep each slider in sync with its text box. Assign a *number*: the old
    # code assigned the string returned by toFixed() to the numeric
    # Slider.value.
    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={'z': z_textinput, 'slider': z_slider},
            code="""
            // Update slider value to match text input
            const z_num = parseFloat(z.value);
            if (!isNaN(z_num)) {
                slider.value = Number(z_num.toFixed(3));
            }
            """,
        ),
    )
    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={'v_exp': v_exp_textinput, 'slider': v_exp_slider},
            code="""
            // Update slider value to match text input
            const v_num = parseFloat(v_exp.value);
            if (!isNaN(v_num)) {
                slider.value = Number(v_num.toFixed(3));
            }
            """,
        ),
    )

    # Update the element spectral lines as well
    for callback in callbacks:
        z_textinput.js_on_change('value', callback)
        v_exp_textinput.js_on_change('value', callback)

    row1 = row(plot, toggle)
    row2 = row(elements_groups)
    row3 = row(z, v_exp)
    layout = column(row1, row2, row3, width=width)
    return bokeh_embed.json_item(layout)
def bokeh_plot(import_df):
    """Build the tabbed Bokeh inventory dashboard for ``import_df``.

    The frame is expected (not validated here) to provide the columns the
    charts reference: ``location_reference_id``, ``region``, ``date``, the
    plain-named metrics (``Backroom_OH``, ``Inbound``, ``Outbound``, ``BOH``,
    ``EOH``, ``OTL``, ``DFE_Q98``, ``BOH_OOS``, ``EOH_OOS``) and the
    human-readable metric columns listed in ``column_aliases`` below.

    Args:
        import_df: pandas DataFrame of raw metrics. It is not modified.

    Returns:
        A Bokeh ``layout`` with the location/region selectors, a summary
        panel, one tab per metric chart and a per-location summary table.
    """
    import pandas as pd
    from bokeh.plotting import figure
    from bokeh.layouts import layout, widgetbox, column
    from bokeh.models import (ColumnDataSource, HoverTool, BoxZoomTool,
                              ResetTool, PanTool, CustomJS, WheelZoomTool,
                              SaveTool, LassoSelectTool, NumeralTickFormatter)
    from bokeh.models.widgets import (Select, TextInput, Div, Tabs, Panel,
                                      DataTable, TableColumn, NumberFormatter)
    from bokeh.transform import dodge
    from bokeh.core.properties import value

    # Underscore-named copies of columns whose names contain spaces or
    # operators, so HoverTool "@field" tooltips can reference them.
    column_aliases = {
        'BOH_gt_Shelf_Capacity': 'BOH > Shelf Capacity',
        'OTL_gt_Shelf_Capacity': 'OTL > Shelf Capacity',
        'Ideal_BOH_gt_Shelf_Capacity': 'Ideal BOH > Shelf Capacity',
        'BOH_lt_Ideal': 'BOH < Ideal',
        'BOH_eq_Ideal': 'BOH = Ideal',
        'BOH_gt_Ideal': 'BOH > Ideal',
        'Demand_Fulfilled': 'Demand Fulfilled',
        'Fill_Rate': 'Fill Rate',
        'Total_OH': 'Total OH',
        'Prop_OH_in_Backroom': 'Prop OH in Backroom',
        'Never_Q98_gt_POG': 'Never: Q98 > POG',
        'Never_Ideal_BOH_gt_POG': 'Never: Ideal BOH > POG',
        'Sometimes_OTL_Casepack_1_gt_POG': 'Sometimes: OTL+Casepack-1 > POG',
        'Always_OTL_Casepack_1_le_POG': 'Always: OTL+Casepack-1 <= POG',
        'Non_POG': 'Non-POG',
    }

    # dropna returns a new frame: the original discarded the result, so
    # all-empty rows were never removed. The explicit copy also keeps the
    # column additions below from touching the caller's DataFrame.
    df = import_df.dropna(how='all', axis=0).copy()
    df['location_reference_id'] = df['location_reference_id'].astype(str)
    df['date'] = pd.to_datetime(df['date'])
    for alias, raw_col in column_aliases.items():
        df[alias] = df[raw_col]
    # String dates serve as the categorical x-range of the vbar charts.
    df['date_bar'] = df['date'].astype(str)

    # Per-location totals for the summary table. numeric_only=True keeps the
    # datetime/string columns out of the sum (pandas >= 2 raises on them).
    df_agg = df.groupby(['location_reference_id'],
                        as_index=False).sum(numeric_only=True)
    source1 = ColumnDataSource(data=df_agg)

    sdate = df['date'].min()
    edate = df['date'].max()
    nodes = df['location_reference_id'].nunique()
    days = df['date'].nunique()
    policy = "Prod"

    # Categories for the vbar charts' x axes.
    x_range_list = list(df['date_bar'].unique())
    all_locations1 = list(df['location_reference_id'].unique())
    all_regions = list(df['region'].unique())

    desc = Div(text="All locations", width=230)
    pre = Div(text="_", width=230)
    location = Select(title="Location", options=all_locations1,
                      value="All_Stores_NE")
    region = Select(title="Region", options=all_regions, value="NE")
    text_input = TextInput(value="default", title="Search Location:")

    # `source` is filtered in the browser by the callback; `original_source`
    # keeps the full data set to filter from.
    source = ColumnDataSource(data=df)
    original_source = ColumnDataSource(data=df)

    def make_tools(*metric_tooltips):
        """Return the shared tool list with a hover showing location, date and *metric_tooltips*."""
        hover = HoverTool(tooltips=[("Location", "@location_reference_id"),
                                    ("Date", "@date_bar"), *metric_tooltips])
        return [hover, BoxZoomTool(), LassoSelectTool(), WheelZoomTool(),
                PanTool(), ResetTool(), SaveTool()]

    def shade_background(fig):
        """Apply the light grey plot background used by most charts."""
        fig.background_fill_color = "#e6e9ed"
        fig.background_fill_alpha = 0.5

    def style_legend(fig):
        """Apply the shared legend border styling (only call on figures that have a legend)."""
        fig.legend.border_line_width = 3
        fig.legend.border_line_color = None
        fig.legend.border_line_alpha = 0.5

    # Backroom on hand
    p = figure(x_range=x_range_list, plot_width=1000, plot_height=525,
               title="Backroom On hand by store",
               tools=make_tools(("Backroom_OH", "@Backroom_OH{0,0.00}")),
               toolbar_location='above', x_axis_label="Date",
               y_axis_label="Backroom OH")
    shade_background(p)
    p.vbar(x=dodge('date_bar', -0.25, range=p.x_range), top='Backroom_OH',
           hover_alpha=0.5, hover_line_color='black', width=0.8,
           source=source, color="#718dbf")
    p.xaxis.major_label_orientation = 1
    p.title.text_color = "olive"

    # Inbound / outbound
    m = figure(plot_height=525, plot_width=1000, x_range=x_range_list,
               title="Inbound/Outbound by store",
               tools=make_tools(("Inbound", "@Inbound{0,0.00}"),
                                ("Outbound", "@Outbound{0,0.00}")),
               toolbar_location='above', x_axis_label="Date",
               y_axis_label="Units")
    shade_background(m)
    m.vbar(x=dodge('date_bar', -0.25, range=m.x_range), top='Inbound',
           hover_alpha=0.5, hover_line_color='black', width=0.4,
           source=source, color="#718dbf", legend=value("Inbound"))
    m.vbar(x=dodge('date_bar', 0.25, range=m.x_range), top='Outbound',
           hover_alpha=0.5, hover_line_color='black', width=0.4,
           source=source, color="#e84d60", legend=value("Outbound"))
    m.xaxis.major_label_orientation = 1
    style_legend(m)
    m.title.text_color = "olive"

    # Stockouts
    s = figure(plot_height=525, plot_width=1000, title="Stockouts by store",
               x_axis_type="datetime", toolbar_location='above',
               tools=make_tools(("BOH_OOS", "@BOH_OOS{0,0.000}"),
                                ("EOH_OOS", "@EOH_OOS{0,0.000}")),
               x_axis_label="Date", y_axis_label="Prop Stockout")
    shade_background(s)
    s.circle(x='date', y='EOH_OOS', source=source, fill_color=None,
             line_color="#4375c6")
    s.line(x='date', y='EOH_OOS', source=source, hover_alpha=0.5,
           hover_line_color='black', line_width=2, line_color='navy',
           legend=value("EOH OOS"))
    s.circle(x='date', y='BOH_OOS', source=source, fill_color=None,
             line_color="#4375c6")
    s.line(x='date', y='BOH_OOS', source=source, hover_alpha=0.5,
           hover_line_color='black', line_width=2, line_color='red',
           legend=value("BOH OOS"))
    style_legend(s)
    s.title.text_color = "olive"

    # Fill rate
    t = figure(plot_height=525, x_range=x_range_list, plot_width=1000,
               title="Fill rates by store",
               tools=make_tools(("Fill Rate", "@Fill_Rate{0,0.00}")),
               toolbar_location='above', x_axis_label="Date",
               y_axis_label="Fill rate")
    shade_background(t)
    t.vbar(x=dodge('date_bar', -0.25, range=t.x_range), top='Fill Rate',
           hover_alpha=0.5, hover_line_color='black', width=0.8,
           source=source, color="#718dbf")
    t.xaxis.major_label_orientation = 1
    t.title.text_color = "olive"

    # % Backroom spillover
    w = figure(plot_height=525, plot_width=1000,
               title="Prop OH in Backroom by store", x_axis_type="datetime",
               tools=make_tools(("Prop OH in Backroom",
                                 "@Prop_OH_in_Backroom{0,0.00}")),
               toolbar_location='above', x_axis_label="Date",
               y_axis_label=" % Backroom spillover")
    shade_background(w)
    w.circle(x='date', y='Prop OH in Backroom', source=source,
             fill_color=None, line_color="#4375c6")
    w.line(x='date', y='Prop OH in Backroom', source=source, hover_alpha=0.5,
           hover_line_color='black', line_width=2, line_color='navy')
    w.title.text_font_style = "bold"
    w.title.text_color = "olive"
    w.yaxis[0].formatter = NumeralTickFormatter(format="0.0%")

    # BOH vs Ideal (stacked)
    colors = ["#c9d9d3", "#718dbf", "#e84d60"]
    BOH_vs_ideal = ['BOH < Ideal', 'BOH > Ideal', 'BOH = Ideal']
    f = figure(x_range=x_range_list, plot_height=525, plot_width=1000,
               title="BOH vs Ideal by store", toolbar_location='above',
               x_axis_label="Date", y_axis_label="Prop",
               tools=make_tools(('BOH < Ideal', "@BOH_lt_Ideal{0,0.00}"),
                                ('BOH > Ideal', "@BOH_gt_Ideal{0,0.00}"),
                                ('BOH = Ideal', "@BOH_eq_Ideal{0,0.00}")))
    f.vbar_stack(BOH_vs_ideal, x='date_bar', width=0.9, color=colors,
                 source=source, legend=[value(x) for x in BOH_vs_ideal],
                 name=BOH_vs_ideal)
    f.xaxis.major_label_orientation = 1
    style_legend(f)
    f.title.text_color = "olive"

    # Pog fit (stacked)
    colors2 = ['#79D151', "#718dbf", '#29788E', '#fc8d59', '#d53e4f']
    pog_fit = [
        'Never: Q98 > POG', 'Never: Ideal BOH > POG',
        'Sometimes: OTL+Casepack-1 > POG', 'Always: OTL+Casepack-1 <= POG',
        'Non-POG'
    ]
    g = figure(x_range=x_range_list, plot_height=525, plot_width=1200,
               title="Pog Fit by store", toolbar_location='above',
               x_axis_label="Date", y_axis_label="Counts",
               tools=make_tools(
                   ('Never: Q98 > POG', "@Never_Q98_gt_POG{0,0.00}"),
                   ("Never: Ideal BOH > POG",
                    "@Never_Ideal_BOH_gt_POG{0,0.00}"),
                   ("Sometimes: OTL+Casepack-1 > POG",
                    "@Sometimes_OTL_Casepack_1_gt_POG{0,0.00}"),
                   ("Always: OTL+Casepack-1 <= POG",
                    "@Always_OTL_Casepack_1_le_POG{0,0.00}"),
                   ("Non-POG", "@Non_POG{0,0.00}")))
    g.vbar_stack(pog_fit, x='date_bar', width=0.9, color=colors2,
                 source=source, legend=[value(x) for x in pog_fit],
                 name=pog_fit)
    g.xaxis.major_label_orientation = 1
    style_legend(g)
    g.title.text_color = "olive"
    g.legend.location = "top_right"

    # BOH vs Pog
    h = figure(plot_height=525, plot_width=1000, title="BOH vs Pog by store",
               x_axis_type="datetime", toolbar_location='above',
               tools=make_tools(
                   ("OTL > Shelf Capacity", "@OTL_gt_Shelf_Capacity{0,0.00}"),
                   ("BOH > Shelf Capacity", "@BOH_gt_Shelf_Capacity{0,0.00}"),
                   ("Ideal BOH > Shelf Capacity",
                    "@Ideal_BOH_gt_Shelf_Capacity{0,0.00}")),
               x_axis_label="Date", y_axis_label="Prop")
    shade_background(h)
    for shelf_col, shelf_color in (('BOH > Shelf Capacity', 'navy'),
                                   ('OTL > Shelf Capacity', "green"),
                                   ('Ideal BOH > Shelf Capacity', "#e84d60")):
        h.circle(x='date', y=shelf_col, source=source, fill_color=None,
                 line_color="#4375c6")
        h.line(x='date', y=shelf_col, source=source, hover_alpha=0.5,
               hover_line_color='black', line_width=2,
               line_color=shelf_color, legend=value(shelf_col))
    style_legend(h)
    h.title.text_color = "olive"
    h.legend.click_policy = "mute"

    # Inventory
    j = figure(plot_height=525, plot_width=1200, x_range=x_range_list,
               title="Inventory by store",
               tools=make_tools(("DFE_Q98", "@DFE_Q98{0,0.00}"),
                                ("OTL", "@OTL{0,0.00}"),
                                ("EOH", "@EOH{0,0.00}"),
                                ("BOH", "@BOH{0,0.00}")),
               toolbar_location='above', x_axis_label="Date",
               y_axis_label="Units")
    shade_background(j)
    for offset, inv_col, inv_color in ((-0.40, 'DFE_Q98', "#FBA40A"),
                                       (-0.20, 'OTL', "#4292c6"),
                                       (0.00, 'EOH', '#a1dab4'),
                                       (0.20, 'BOH', "#DC5039")):
        j.vbar(x=dodge('date_bar', offset, range=j.x_range), top=inv_col,
               hover_alpha=0.3, hover_line_color='black', width=0.2,
               source=source, color=inv_color, legend=value(inv_col))
    j.xaxis.major_label_orientation = 1
    style_legend(j)
    j.title.text_color = "olive"
    j.legend.location = "top_left"
    j.legend.click_policy = "mute"

    pre.text = " <b>Start date:</b> <i>{}</i> <br /> <b>End date:</b> <i>{}</i> <br /> <b>Time period:</b> <i>{}</i> days <br /> <b> Total Number of Nodes:</b> <i>{}</i> <br /> <b>Policy</b> = <i>{}</i><br /> ".format(
        sdate, edate, days, nodes, policy)

    # Browser-side filter: keep only the rows matching the selected location
    # AND region. Uses direct property access and change.emit(); the legacy
    # BokehJS .get()/.trigger('change') calls no longer exist.
    callback = CustomJS(args=dict(source=source,
                                  original_source=original_source,
                                  location_select_obj=location,
                                  region_select_obj=region, div=desc,
                                  text_input=text_input),
                        code="""
        var data = source.data;
        var original_data = original_source.data;
        var loc = location_select_obj.value;
        var reg = region_select_obj.value;
        div.text = " <br /> <b> Region:</b>" + reg + "<br /> <b>Location:</b> " + loc;
        var ids = original_data['location_reference_id'];
        var regions = original_data['region'];
        // Collect matching row indices once instead of once per column.
        var keep = [];
        for (var i = 0; i < ids.length; ++i) {
            if ((ids[i] === loc) && (regions[i] === reg)) {
                keep.push(i);
            }
        }
        for (var key in original_data) {
            var col = original_data[key];
            var filtered = [];
            for (var k = 0; k < keep.length; ++k) {
                filtered.push(col[keep[k]]);
            }
            data[key] = filtered;
        }
        source.change.emit();
    """)
    location.js_on_change('value', callback)
    region.js_on_change('value', callback)
    text_input.js_on_change('value', callback)

    inputs = widgetbox(location, region, desc, pre, width=220, height=500)

    tab1 = Panel(child=p, title='Backroom OH')
    tab2 = Panel(child=s, title='Stockouts')
    tab3 = Panel(child=f, title='BOH vs Ideal')
    tab4 = Panel(child=g, title='Pog Fit')
    tab5 = Panel(child=m, title='Inbound/Outbound')
    tab6 = Panel(child=h, title='BOH vs POG')
    tab7 = Panel(child=t, title='Fill Rate')
    tab8 = Panel(child=j, title='Inventory')
    tab9 = Panel(child=w, title='Prop OH in Backroom')

    # Summary table of per-location totals.
    summary_fields = ["Backroom_OH", "Outbound", "Inbound", "OTL", "DFE_Q98",
                      "BOH", "EOH", "BOH_OOS", "EOH_OOS"]
    columns = [TableColumn(field="location_reference_id", title="Location")]
    columns += [
        TableColumn(field=field, title=field,
                    formatter=NumberFormatter(format="0,0"))
        for field in summary_fields
    ]
    data_table = DataTable(source=source1, columns=columns, width=1250)
    tab10 = Panel(child=data_table, title='Summary Table')

    view = Tabs(
        tabs=[tab1, tab2, tab5, tab8, tab6, tab3, tab7, tab4, tab9, tab10])
    layout_text = column(inputs)
    layout2 = layout(children=[[layout_text, view]],
                     sizing_mode='scale_height')
    return layout2
def main(): print('''Please select the CSV dataset you\'d like to use. The dataset should contain these columns: - metric to apply threshold to - indicator of event to detect (e.g. malicious activity) - Please label this as 1 or 0 (true or false); This will not work otherwise! ''') # Import the dataset imported_data = None while isinstance(imported_data, pd.DataFrame) == False: file_path = input('Enter the path of your dataset: ') imported_data = file_to_df(file_path) time.sleep(1) print(f'''\nGreat! Here is a preview of your data: Imported fields:''') # List headers by column index. cols = list(imported_data.columns) for index in range(len(cols)): print(f'{index}: {cols[index]}') print(f'Number of records: {len(imported_data.index)}\n') # Preview the DataFrame time.sleep(1) print(imported_data.head(), '\n') # Prompt for the metric and source of truth. time.sleep(1) metric_col, indicator_col = columns_picker(cols) # User self-validation. col_check = input('Can you confirm if this is correct? (y/n): ').lower() # If it's wrong, let them try again while col_check != 'y': metric_col, indicator_col = columns_picker(cols) col_check = input( 'Can you confirm if this is correct? (y/n): ').lower() else: print( '''\nGreat! Thanks for your patience. Generating summary stats now..\n''' ) # Generate summary stats. 
time.sleep(1) malicious, normal = classification_split(imported_data, metric_col, indicator_col) mal_mean = malicious.mean() mal_stddev = malicious.std() mal_count = malicious.size mal_median = malicious.median() norm_mean = normal.mean() norm_stddev = normal.std() norm_count = normal.size norm_median = normal.median() print(f'''Normal vs Malicious Summary (metric = {metric_col}): Normal: ----------------------------- Observations: {round(norm_count, 2)} Average: {round(norm_mean, 2)} Median: {round(norm_median, 2)} Standard Deviation: {round(norm_stddev, 2)} Malicious: ----------------------------- Observations: {round(mal_count, 2)} Average: {round(mal_mean, 2)} Median: {round(mal_median, 2)} Standard Deviation: {round(mal_stddev, 2)} ''') # Insights and advisories # Provide the accuracy metrics of a generic threshold at avg + 3 std deviations generic_threshold = confusion_matrix( malicious, normal, threshold_calc(norm_mean, norm_stddev, 3)) time.sleep(1) print( f'''A threshold at (average + 3x standard deviations) {metric_col} would result in: - True Positives (correctly identified malicious events: {generic_threshold['TP']:,} - False Positives (wrongly identified normal events: {generic_threshold['FP']:,} - True Negatives (correctly identified normal events: {generic_threshold['TN']:,} - False Negatives (wrongly identified malicious events: {generic_threshold['FN']:,} Accuracy Metrics: - Precision (what % of events above threshold are actually malicious): {round(generic_threshold['precision'] * 100, 1)}% - Recall (what % of malicious events did we catch): {round(generic_threshold['recall'] * 100, 1)}% - F1 Score (blends precision and recall): {round(generic_threshold['f1_score'] * 100, 1)}%''' ) # Distribution skew check. if norm_mean >= (norm_median * 1.1): time.sleep(1) print( f'''\nYou may want to be cautious as your normal traffic\'s {metric_col} has a long tail towards high values. 
The median is {round(norm_median, 2)} compared to {round(norm_mean, 2)} for the average.''') if mal_mean < threshold_calc(norm_mean, norm_stddev, 2): time.sleep(1) print( f'''\nWarning: you may find it difficult to avoid false positives as the average {metric_col} for malicious traffic is under the 95th percentile of the normal traffic.''' ) # For fun/anticipation. Actually a nerd joke because of the method we'll be using. if '-q' not in sys.argv[1:]: time.sleep(1) play_a_game.billy() decision = input('yes/no: ').lower() while decision != 'yes': time.sleep(1) print('...That\'s no fun...') decision = input('Let\'s try that again: ').lower() # Let's get to the simulations! time.sleep(1) print('''\nInstead of manually experimenting with threshold multipliers, let\'s simulate a range of options and see what produces the best result. This is similar to what is known as \"Monte Carlo simulation\".\n''') # Initialize session name & create app folder if there isn't one. time.sleep(1) session_name = input('Please provide a name for this project/session: ') session_folder = make_folder(session_name) # Generate list of multipliers to iterate over. time.sleep(1) mult_start = float( input( 'Please provide the minimum multiplier you want to start at. We recommend 2: ' )) # Set the max to how many std deviations away the sample max is. mult_end = (imported_data[metric_col].max() - norm_mean) / norm_stddev mult_interval = float( input('Please provide the desired gap between multiplier options: ')) # range() only allows integers, let's manually populate a list multipliers = [] mult_counter = mult_start while mult_counter < mult_end: multipliers.append(round(mult_counter, 2)) mult_counter += mult_interval print('Generating simulations..\n') # Run simulations using our multipliers. simulations = monte_carlo(malicious, normal, norm_mean, norm_stddev, multipliers) print('Done!') time.sleep(1) # Save simulations as CSV for later use. 
simulation_filepath = os.path.join( session_folder, f'{session_name}_simulation_results.csv') simulations.to_csv(simulation_filepath, index=False) print(f'Saved results to: {simulation_filepath}') # Find the first threshold with the highest F1 score. # This provides a balanced approach between precision and recall. f1_max = simulations[simulations.f1_score == simulations.f1_score.max()].head(1) f1_max_mult = f1_max.squeeze()['multiplier'] time.sleep(1) print( f'''\nBased on the F1 score metric, setting a threshold at {round(f1_max_mult,1)} standard deviations above the average magnitude might provide optimal results.\n''') time.sleep(1) print(f'''{f1_max} We recommend that you skim the CSV and the following visualization outputs to sanity check results and make your own judgement. ''') # Now for the fun part..generating the visualizations via Bokeh. # Header & internal CSS. title_text = ''' <style> @font-face { font-family: RobotoBlack; src: url(fonts/Roboto-Black.ttf); font-weight: bold; } @font-face { font-family: RobotoBold; src: url(fonts/Roboto-Bold.ttf); font-weight: bold; } @font-face { font-family: RobotoRegular; src: url(fonts/Roboto-Regular.ttf); } body { background-color: #f2ebe6; } title_header { font-size: 80px; font-style: bold; font-family: RobotoBlack, Helvetica; font-weight: bold; margin-bottom: -200px; } h1, h2, h3 { font-family: RobotoBlack, Helvetica; color: #313596; } p { font-size: 12px; font-family: RobotoRegular } b { color: #58c491; } th, td { text-align:left; padding: 5px; } tr:nth-child(even) { background-color: white; opacity: .7; } .vertical { border-left: 1px solid black; height: 190px; } </style> <title_header style="text-align:left; color: white;"> Cream. </title_header> <p style="font-family: RobotoBold, Helvetica; font-size:18px; margin-top: 0px; margin-left: 5px;"> Because time is money, and <b style="font-size=18px;">"Cash Rules Everything Around Me"</b>. 
</p> </div> ''' title_div = Div(text=title_text, width=800, height=160, margin=(40, 0, 0, 70)) # Summary stats from earlier. summary_text = f''' <h1>Results Overview</h1> <i>metric = magnitude</i> <table style="width:100%"> <tr> <th>Metric</th> <th>Normal Events</th> <th>Malicious Events</th> </tr> <tr> <td>Observations</td> <td>{norm_count:,}</td> <td>{mal_count:,}</td> </tr> <tr> <td>Average</td> <td>{round(norm_mean, 2):,}</td> <td>{round(mal_mean, 2):,}</td> </tr> <tr> <td>Median</td> <td>{round(norm_median, 2):,}</td> <td>{round(mal_median, 2):,}</td> </tr> <tr> <td>Standard Deviation</td> <td>{round(norm_stddev, 2):,}</td> <td>{round(mal_stddev, 2):,}</td> </tr> </table> ''' summary_div = Div(text=summary_text, width=470, height=320, margin=(3, 0, -70, 73)) # Results of the hypothetical threshold. hypothetical = f''' <h1>"Rule of thumb" Hypothetical Threshold</h1> <p>A threshold at <i>(average + 3x standard deviations)</i> {metric_col} would result in:</p> <ul> <li>True Positives (correctly identified malicious events: <b>{generic_threshold['TP']:,}</b></li> <li>False Positives (wrongly identified normal events: <b>{generic_threshold['FP']:,}</b></li> <li>True Negatives (correctly identified normal events: <b>{generic_threshold['TN']:,}</b></li> <li>False Negatives (wrongly identified malicious events: <b>{generic_threshold['FN']:,}</b></li> </ul> <h2>Accuracy Metrics</h2> <ul> <li>Precision (what % of events above threshold are actually malicious): <b>{round(generic_threshold['precision'] * 100, 1)}%</b></li> <li>Recall (what % of malicious events did we catch): <b>{round(generic_threshold['recall'] * 100, 1)}%</b></li> <li>F1 Score (blends precision and recall): <b>{round(generic_threshold['f1_score'] * 100, 1)}%</b></li> </ul> ''' hypo_div = Div(text=hypothetical, width=600, height=320, margin=(5, 0, -70, 95)) line = ''' <div class="vertical"></div> ''' vertical_line = Div(text=line, width=20, height=320, margin=(80, 0, -70, -10)) # Let's get the 
exploratory charts generated. malicious_hist, malicious_edge = np.histogram(malicious, bins=100) mal_hist_df = pd.DataFrame({ 'metric': malicious_hist, 'left': malicious_edge[:-1], 'right': malicious_edge[1:] }) normal_hist, normal_edge = np.histogram(normal, bins=100) norm_hist_df = pd.DataFrame({ 'metric': normal_hist, 'left': normal_edge[:-1], 'right': normal_edge[1:] }) exploratory = figure( plot_width=plot_width, plot_height=plot_height, sizing_mode='fixed', title=f'{metric_col.capitalize()} Distribution (σ = std dev)', x_axis_label=f'{metric_col.capitalize()}', y_axis_label='Observations') exploratory.title.text_font_size = title_font_size exploratory.border_fill_color = cell_bg_color exploratory.border_fill_alpha = cell_bg_alpha exploratory.background_fill_color = cell_bg_color exploratory.background_fill_alpha = plot_bg_alpha exploratory.min_border_left = left_border exploratory.min_border_right = right_border exploratory.min_border_top = top_border exploratory.min_border_bottom = bottom_border exploratory.quad(bottom=0, top=mal_hist_df.metric, left=mal_hist_df.left, right=mal_hist_df.right, legend_label='malicious', fill_color=malicious_color, alpha=.85, line_alpha=.35, line_width=.5) exploratory.quad(bottom=0, top=norm_hist_df.metric, left=norm_hist_df.left, right=norm_hist_df.right, legend_label='normal', fill_color=normal_color, alpha=.35, line_alpha=.35, line_width=.5) exploratory.add_layout( Arrow(end=NormalHead(fill_color=malicious_color, size=10, line_alpha=0), line_color=malicious_color, x_start=mal_mean, y_start=mal_count, x_end=mal_mean, y_end=0)) arrow_label = Label(x=mal_mean, y=mal_count, y_offset=5, text='Malicious Events', text_font_style='bold', text_color=malicious_color, text_font_size='10pt') exploratory.add_layout(arrow_label) exploratory.xaxis.formatter = NumeralTickFormatter(format='0,0') exploratory.yaxis.formatter = NumeralTickFormatter(format='0,0') # 3 sigma reference line sigma_ref(exploratory, norm_mean, norm_stddev) 
exploratory.legend.location = "top_right" exploratory.legend.background_fill_alpha = .3 # Zoomed in version overlap_view = figure( plot_width=plot_width, plot_height=plot_height, sizing_mode='fixed', title=f'Overlap Highlight', x_axis_label=f'{metric_col.capitalize()}', y_axis_label='Observations', y_range=(0, mal_count * .33), x_range=(norm_mean + (norm_stddev * 2.5), mal_mean + (mal_stddev * 3)), ) overlap_view.title.text_font_size = title_font_size overlap_view.border_fill_color = cell_bg_color overlap_view.border_fill_alpha = cell_bg_alpha overlap_view.background_fill_color = cell_bg_color overlap_view.background_fill_alpha = plot_bg_alpha overlap_view.min_border_left = left_border overlap_view.min_border_right = right_border overlap_view.min_border_top = top_border overlap_view.min_border_bottom = bottom_border overlap_view.quad(bottom=0, top=mal_hist_df.metric, left=mal_hist_df.left, right=mal_hist_df.right, legend_label='malicious', fill_color=malicious_color, alpha=.85, line_alpha=.35, line_width=.5) overlap_view.quad(bottom=0, top=norm_hist_df.metric, left=norm_hist_df.left, right=norm_hist_df.right, legend_label='normal', fill_color=normal_color, alpha=.35, line_alpha=.35, line_width=.5) overlap_view.xaxis.formatter = NumeralTickFormatter(format='0,0') overlap_view.yaxis.formatter = NumeralTickFormatter(format='0,0') sigma_ref(overlap_view, norm_mean, norm_stddev) overlap_view.legend.location = "top_right" overlap_view.legend.background_fill_alpha = .3 # Probability Density - bigger bins for sparser malicous observations malicious_hist_dense, malicious_edge_dense = np.histogram(malicious, density=True, bins=50) mal_hist_dense_df = pd.DataFrame({ 'metric': malicious_hist_dense, 'left': malicious_edge_dense[:-1], 'right': malicious_edge_dense[1:] }) normal_hist_dense, normal_edge_dense = np.histogram(normal, density=True, bins=100) norm_hist_dense_df = pd.DataFrame({ 'metric': normal_hist_dense, 'left': normal_edge_dense[:-1], 'right': normal_edge_dense[1:] 
}) density = figure(plot_width=plot_width, plot_height=plot_height, sizing_mode='fixed', title='Probability Density', x_axis_label=f'{metric_col.capitalize()}', y_axis_label='% of Group Total') density.title.text_font_size = title_font_size density.border_fill_color = cell_bg_color density.border_fill_alpha = cell_bg_alpha density.background_fill_color = cell_bg_color density.background_fill_alpha = plot_bg_alpha density.min_border_left = left_border density.min_border_right = right_border density.min_border_top = top_border density.min_border_bottom = bottom_border density.quad(bottom=0, top=mal_hist_dense_df.metric, left=mal_hist_dense_df.left, right=mal_hist_dense_df.right, legend_label='malicious', fill_color=malicious_color, alpha=.85, line_alpha=.35, line_width=.5) density.quad(bottom=0, top=norm_hist_dense_df.metric, left=norm_hist_dense_df.left, right=norm_hist_dense_df.right, legend_label='normal', fill_color=normal_color, alpha=.35, line_alpha=.35, line_width=.5) density.xaxis.formatter = NumeralTickFormatter(format='0,0') density.yaxis.formatter = NumeralTickFormatter(format='0.000%') sigma_ref(density, norm_mean, norm_stddev) density.legend.location = "top_right" density.legend.background_fill_alpha = .3 # Simulation Series to be used false_positives = simulations.FP false_negatives = simulations.FN multiplier = simulations.multiplier precision = simulations.precision recall = simulations.recall f1_score = simulations.f1_score f1_max = simulations[simulations.f1_score == simulations.f1_score.max( )].head(1).squeeze()['multiplier'] # False Positives vs False Negatives errors = figure(plot_width=plot_width, plot_height=plot_height, sizing_mode='fixed', x_range=(multiplier.min(), multiplier.max()), y_range=(0, false_positives.max()), title='False Positives vs False Negatives', x_axis_label='Multiplier', y_axis_label='Count') errors.title.text_font_size = title_font_size errors.border_fill_color = cell_bg_color errors.border_fill_alpha = cell_bg_alpha 
errors.background_fill_color = cell_bg_color errors.background_fill_alpha = plot_bg_alpha errors.min_border_left = left_border errors.min_border_right = right_border errors.min_border_top = top_border errors.min_border_bottom = right_border errors.line(multiplier, false_positives, legend_label='false positives', line_width=2, color=fp_color) errors.line(multiplier, false_negatives, legend_label='false negatives', line_width=2, color=fn_color) errors.yaxis.formatter = NumeralTickFormatter(format='0,0') errors.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)} errors.add_layout( LinearAxis(y_range_name="y2", axis_label="Score", formatter=NumeralTickFormatter(format='0.00%')), 'right') errors.line(multiplier, f1_score, line_width=2, color=f1_color, legend_label='F1 Score', y_range_name="y2") # F1 Score Maximization point f1_thresh = Span(location=f1_max, dimension='height', line_color=f1_color, line_dash='dashed', line_width=2) f1_label = Label(x=f1_max + .05, y=180, y_units='screen', text=f'F1 Max: {round(f1_max,2)}', text_font_size='10pt', text_font_style='bold', text_align='left', text_color=f1_color) errors.add_layout(f1_thresh) errors.add_layout(f1_label) errors.legend.location = "top_right" errors.legend.background_fill_alpha = .3 # False Negative Weighting. # Intro. weighting_intro = f''' <h3>Error types differ in impact.</h3> <p>In the case of security incidents, a false negative, though possibly rarer than false positives, is likely more costly. For example, downtime suffered from a DDoS attack (lost sales/customers) incurs more loss than time wasted chasing a false positive (labor hours). </p> <p>Try playing around with the slider to the right to see how your thresholding strategy might need to change depending on the relative weight of false negatives to false positives. 
What does it look like at 1:1, 50:1, etc.?</p> ''' weighting_div = Div(text=weighting_intro, width=420, height=180, margin=(0, 75, 0, 0)) # Now for the weighted errors viz default_weighting = 10 initial_fp_cost = 100 simulations['weighted_FN'] = simulations.FN * default_weighting weighted_fn = simulations.weighted_FN simulations[ 'total_weighted_error'] = simulations.FP + simulations.weighted_FN total_weighted_error = simulations.total_weighted_error simulations['fp_cost'] = initial_fp_cost fp_cost = simulations.fp_cost simulations[ 'total_estimated_cost'] = simulations.total_weighted_error * simulations.fp_cost total_estimated_cost = simulations.total_estimated_cost twe_min = simulations[simulations.total_weighted_error == simulations.total_weighted_error.min()].head( 1).squeeze()['multiplier'] twe_min_count = simulations[simulations.multiplier == twe_min].head( 1).squeeze()['total_weighted_error'] generic_twe = simulations[simulations.multiplier.apply( lambda x: round(x, 2)) == 3.00].squeeze()['total_weighted_error'] comparison = f''' <p>Based on your inputs, the optimal threshold is around <b>{twe_min}</b>. 
This would result in an estimated <b>{int(twe_min_count):,}</b> total weighted errors and <b>${int(twe_min_count * initial_fp_cost):,}</b> in losses.</p> <p>The generic threshold of 3.0 standard deviations would result in <b>{int(generic_twe):,}</b> total weighted errors and <b>${int(generic_twe * initial_fp_cost):,}</b> in losses.</p> <p>Using the optimal threshold would save <b>${int((generic_twe - twe_min_count) * initial_fp_cost):,}</b>, reducing costs by <b>{(generic_twe - twe_min_count) / generic_twe * 100:.1f}%</b> (assuming near-future events are distributed similarly to those from the past).</p> ''' comparison_div = Div(text=comparison, width=420, height=230, margin=(0, 75, 0, 0)) loss_min = ColumnDataSource(data=dict(multiplier=multiplier, fp=false_positives, fn=false_negatives, weighted_fn=weighted_fn, twe=total_weighted_error, fpc=fp_cost, tec=total_estimated_cost, precision=precision, recall=recall, f1=f1_score)) evaluation = Figure(plot_width=900, plot_height=520, sizing_mode='fixed', x_range=(multiplier.min(), multiplier.max()), title='Evaluation Metrics vs Total Estimated Cost', x_axis_label='Multiplier', y_axis_label='Cost') evaluation.title.text_font_size = title_font_size evaluation.border_fill_color = cell_bg_color evaluation.border_fill_alpha = cell_bg_alpha evaluation.background_fill_color = cell_bg_color evaluation.background_fill_alpha = plot_bg_alpha evaluation.min_border_left = left_border evaluation.min_border_right = right_border evaluation.min_border_top = top_border evaluation.min_border_bottom = bottom_border evaluation.line('multiplier', 'tec', source=loss_min, line_width=3, line_alpha=0.6, color=total_weighted_color, legend_label='Total Estimated Cost') evaluation.yaxis.formatter = NumeralTickFormatter(format='$0,0') # Evaluation metrics on second right axis. 
evaluation.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)} evaluation.add_layout( LinearAxis(y_range_name="y2", axis_label="Score", formatter=NumeralTickFormatter(format='0.00%')), 'right') evaluation.line('multiplier', 'precision', source=loss_min, line_width=3, line_alpha=0.6, color=precision_color, legend_label='Precision', y_range_name="y2") evaluation.line('multiplier', 'recall', source=loss_min, line_width=3, line_alpha=0.6, color=recall_color, legend_label='Recall', y_range_name="y2") evaluation.line('multiplier', 'f1', source=loss_min, line_width=3, line_alpha=0.6, color=f1_color, legend_label='F1 score', y_range_name="y2") evaluation.legend.location = "bottom_right" evaluation.legend.background_fill_alpha = .3 twe_thresh = Span(location=twe_min, dimension='height', line_color=total_weighted_color, line_dash='dashed', line_width=2) twe_label = Label(x=twe_min - .05, y=240, y_units='screen', text=f'Cost Min: {round(twe_min,2)}', text_font_size='10pt', text_font_style='bold', text_align='right', text_color=total_weighted_color) evaluation.add_layout(twe_thresh) evaluation.add_layout(twe_label) # Add in same f1 thresh as previous viz evaluation.add_layout(f1_thresh) evaluation.add_layout(f1_label) handler = CustomJS(args=dict(source=loss_min, thresh=twe_thresh, label=twe_label, comparison=comparison_div), code=""" var data = source.data var ratio = cb_obj.value var multiplier = data['multiplier'] var fp = data['fp'] var fn = data['fn'] var weighted_fn = data['weighted_fn'] var twe = data['twe'] var fpc = data['fpc'] var tec = data['tec'] var generic_twe = 0 function round(value, decimals) { return Number(Math.round(value+'e'+decimals)+'e-'+decimals); } function comma_sep(x) { return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ","); } for (var i = 0; i < multiplier.length; i++) { weighted_fn[i] = Math.round(fn[i] * ratio) twe[i] = weighted_fn[i] + fp[i] tec[i] = twe[i] * fpc[i] if (round(multiplier[i],2) == 3.00) { generic_twe = twe[i] } } var 
min_loss = Math.min.apply(null,twe) var new_thresh = 0 for (var i = 0; i < multiplier.length; i++) { if (twe[i] == min_loss) { new_thresh = multiplier[i] thresh.location = new_thresh thresh.change.emit() label.x = new_thresh label.text = `Cost Min: ${new_thresh}` label.change.emit() comparison.text = ` <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>. This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and <b>$${comma_sep(round(min_loss * fpc[i],0))}</b> in losses.</p> <p>The generic threshold of 3.0 standard deviations would result in <b>${comma_sep(round(generic_twe,0))}</b> total weighted errors and <b>$${comma_sep(round(generic_twe * fpc[i],0))}</b> in losses.</p> <p>Using the optimal threshold would save <b>$${comma_sep(round((generic_twe - min_loss) * fpc[i],0))}</b>, reducing costs by <b>${comma_sep(round((generic_twe - min_loss) / generic_twe * 100,0))}%</b> (assuming near-future events are distributed similarly to those from the past).</p> ` comparison.change.emit() } } source.change.emit(); """) slider = Slider(start=1.0, end=500, value=default_weighting, step=.25, title="FN:FP Ratio", bar_color='#FFD100', height=50, margin=(5, 0, 5, 0)) slider.js_on_change('value', handler) cost_handler = CustomJS(args=dict(source=loss_min, comparison=comparison_div), code=""" var data = source.data var new_cost = cb_obj.value var multiplier = data['multiplier'] var fp = data['fp'] var fn = data['fn'] var weighted_fn = data['weighted_fn'] var twe = data['twe'] var fpc = data['fpc'] var tec = data['tec'] var generic_twe = 0 function round(value, decimals) { return Number(Math.round(value+'e'+decimals)+'e-'+decimals); } function comma_sep(x) { return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ","); } for (var i = 0; i < multiplier.length; i++) { fpc[i] = new_cost tec[i] = twe[i] * fpc[i] if (round(multiplier[i],2) == 3.00) { generic_twe = twe[i] } } var min_loss = Math.min.apply(null,twe) 
var new_thresh = 0 for (var i = 0; i < multiplier.length; i++) { if (twe[i] == min_loss) { new_thresh = multiplier[i] comparison.text = ` <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>. This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and <b>$${comma_sep(round(min_loss * new_cost,0))}</b> in losses.</p> <p>The generic threshold of 3.0 standard deviations would result in <b>${comma_sep(round(generic_twe,0))}</b> total weighted errors and <b>$${comma_sep(round(generic_twe * new_cost,0))}</b> in losses.</p> <p>Using the optimal threshold would save <b>$${comma_sep(round((generic_twe - min_loss) * new_cost,0))}</b>, reducing costs by <b>${comma_sep(round((generic_twe - min_loss)/generic_twe * 100,0))}%</b> (assuming near-future events are distributed similarly to those from the past).</p> ` comparison.change.emit() } } source.change.emit(); """) cost_input = TextInput(value=f"{initial_fp_cost}", title="How much a false positive costs:", height=75, margin=(20, 75, 20, 0)) cost_input.js_on_change('value', cost_handler) # Include DataTable of simulation results dt_columns = [ TableColumn(field="multiplier", title="Multiplier"), TableColumn(field="fp", title="False Positives", formatter=NumberFormatter(format='0,0')), TableColumn(field="fn", title="False Negatives", formatter=NumberFormatter(format='0,0')), TableColumn(field="weighted_fn", title="Weighted False Negatives", formatter=NumberFormatter(format='0,0.00')), TableColumn(field="twe", title="Total Weighted Errors", formatter=NumberFormatter(format='0,0.00')), TableColumn(field="fpc", title="Estimated FP Cost", formatter=NumberFormatter(format='$0,0.00')), TableColumn(field="tec", title="Estimated Total Cost", formatter=NumberFormatter(format='$0,0.00')), TableColumn(field="precision", title="Precision", formatter=NumberFormatter(format='0.00%')), TableColumn(field="recall", title="Recall", formatter=NumberFormatter(format='0.00%')), 
TableColumn(field="f1", title="F1 Score", formatter=NumberFormatter(format='0.00%')), ] data_table = DataTable(source=loss_min, columns=dt_columns, width=1400, height=700, sizing_mode='fixed', fit_columns=True, reorderable=True, sortable=True, margin=(30, 0, 20, 0)) # weighting_layout = column([weighting_div, evaluation, slider, data_table]) weighting_layout = column( row(column(weighting_div, cost_input, comparison_div), column(slider, evaluation), Div(text='', height=200, width=60)), data_table) # Initialize visualizations in browser time.sleep(1.5) layout = grid([ [title_div], [row(summary_div, vertical_line, hypo_div)], [ row(Div(text='', height=200, width=60), exploratory, Div(text='', height=200, width=10), overlap_view, Div(text='', height=200, width=40)) ], [Div(text='', height=10, width=200)], [ row(Div(text='', height=200, width=60), density, Div(text='', height=200, width=10), errors, Div(text='', height=200, width=40)) ], [Div(text='', height=10, width=200)], [ row(Div(text='', height=200, width=60), weighting_layout, Div(text='', height=200, width=40)) ], ]) # Generate html resources for dashboard fonts = os.path.join(os.getcwd(), 'fonts') if os.path.isdir(os.path.join(session_folder, 'fonts')): shutil.rmtree(os.path.join(session_folder, 'fonts')) shutil.copytree(fonts, os.path.join(session_folder, 'fonts')) else: shutil.copytree(fonts, os.path.join(session_folder, 'fonts')) html = file_html(layout, INLINE, "Cream") with open(os.path.join(session_folder, f'{session_name}.html'), "w") as file: file.write(html) webbrowser.open("file://" + os.path.join(session_folder, f'{session_name}.html'))
def bokeh_ajax(request):
    """Build the live sea-surface-temperature plot and its widgets for embedding.

    The plot starts empty and is filled by an AJAX POST to ``/AJAXdata2``
    when the "update graph" button is pressed. Two text inputs let the user
    move the start/end of the x (date) range from the browser.

    Parameters
    ----------
    request :
        The incoming web request. It is not read here; it is kept for the
        view-function signature.

    Returns
    -------
    dict
        Template context containing the inline Bokeh JS/CSS resources and
        the ``script2``/``div2`` embedding components for the layout.
    """
    # NOTE(review): dfDict / dfStatic are module-level objects not visible
    # here; presumably pandas DataFrames with a datetime 'time' column.
    startDt = dfDict['time'][0].to_pydatetime()
    endDt = dfStatic['time'].iloc[-1].to_pydatetime()

    # note: need the below in order to display the bokeh plot
    jsResources = INLINE.render_js()
    # need the below in order to be able to properly interact with the plot
    # and have the default bokeh plot interaction tool to display
    cssResources = INLINE.render_css()

    source2 = ColumnDataSource(data={"time": [], "temperature": [], "id": []})

    livePlot2 = figure(x_axis_type="datetime", x_range=[startDt, endDt], y_range=(0, 25),
                       y_axis_label='Temperature (Celsius)',
                       title="Sea Surface Temperature at 43.18, -70.43", plot_width=800)
    livePlot2.line("time", "temperature", source=source2)

    # Date.parse() yields NaN for text it cannot parse; only move the range
    # when the input is a valid date so a typo does not break the plot.
    updateStartJS = CustomJS(args=dict(plotRange=livePlot2.x_range), code="""
        var newStart = Date.parse(cb_obj.value)
        if (!isNaN(newStart)) {
            plotRange.start = newStart
            plotRange.change.emit()
        }
    """)

    updateEndJS = CustomJS(args=dict(plotRange=livePlot2.x_range), code="""
        var newEnd = Date.parse(cb_obj.value)
        if (!isNaN(newEnd)) {
            plotRange.end = newEnd
            plotRange.change.emit()
        }
    """)

    startInput = TextInput(value=startDt.strftime(dateFmt),
                           title="Enter start date in format: YYYY-mm-dd")
    startInput.js_on_change('value', updateStartJS)
    endInput = TextInput(value=endDt.strftime(dateFmt),
                         title="Enter end date in format: YYYY-mm-dd")
    endInput.js_on_change('value', updateEndJS)
    textWidgets = row(startInput, endInput)

    # https://stackoverflow.com/questions/37083998/flask-bokeh-ajaxdatasource
    # above stackoverflow helped a lot and is what the below CustomJS is based on
    # The server response is appended to the existing columns, so repeated
    # clicks accumulate data rather than replacing it.
    callback = CustomJS(args=dict(source=source2), code="""
        var plot_data = source.data;

        jQuery.ajax({
            type: 'POST',
            url: '/AJAXdata2',
            data: {},
            dataType: 'json',
            success: function (json_from_server) {
                plot_data['temperature'] = plot_data['temperature'].concat(json_from_server['temperature']);
                plot_data['time'] = plot_data['time'].concat(json_from_server['time']);
                plot_data['id'] = plot_data['id'].concat(json_from_server['id']);
                source.change.emit();
            },
            error: function() {
                alert("Oh no, something went wrong. Search for an error " +
                      "message in Flask log and browser developer tools.");
            }
        });
        """)

    manualUpdate = Button(label="update graph", callback=callback)
    widgets = widgetbox([manualUpdate])

    # IMPORTANT: the widgets that control a plot must be in the same layout
    # object as that plot. Then, when components() is called on the layout,
    # the widgets and the plot live in the same embedded document. If they are
    # not, the JS callbacks do not work, presumably because the models cannot
    # reference one another.
    layout2 = column(widgets, textWidgets, livePlot2)
    script2, div2 = components(layout2)
    return {
        'someword': "hello",
        'jsResources': jsResources,
        'cssResources': cssResources,
        'script2': script2,
        'div2': div2
    }
class ViewerVIWidgets(object):
    """
    Encapsulates Bokeh widgets, and related callbacks, used for VI

    The ``add_*`` methods create widgets as attributes of the instance. Some
    methods read widgets created by others, so call order matters:

    - ``add_vi_scanner`` needs ``add_filename`` first.
    - ``add_vi_storage`` needs ``add_filename``, ``add_vi_comment``,
      ``add_vi_scanner``, ``add_vi_quality`` and ``add_vi_issues`` first.
    """

    # JS for the countdown toggle (``vi_countdown`` is in minutes).
    # The remaining time is shown as *total* minutes plus seconds. The
    # previous version took minutes modulo one hour and never showed hours,
    # so countdowns of 60 min or more displayed the wrong remaining time.
    _COUNTDOWN_JS = """
        if ( (cb_obj.label).includes('Start') ) { // Callback doesn't do anything after countdown started
            var countDownDate = new Date().getTime() + (1000 * 60 * vi_countdown);
            var countDownLoop = setInterval(function(){
                var now = new Date().getTime();
                var distance = countDownDate - now;
                if (distance<0) {
                    cb_obj.label = "Time's up !";
                    clearInterval(countDownLoop);
                } else {
                    var minutes = Math.floor(distance / (1000 * 60));
                    var seconds = Math.floor((distance % (1000 * 60)) / 1000);
                    var stuff = minutes + "m " + seconds + "s ";
                    cb_obj.label = "Countdown: " + stuff;
                }
            }, 1000);
        }
        """

    def __init__(self, title, viewer_cds):
        """Collect VI flag labels, JS resources and the output CSV field list.

        Parameters
        ----------
        title : str
            Used in the default VI file name and as the autosave key passed
            to the JS callbacks.
        viewer_cds :
            Object exposing ``cds_metadata`` (a ColumnDataSource).
        """
        self.vi_quality_labels = [x["label"] for x in vi_flags if x["type"] == "quality"]
        self.vi_issue_labels = [x["label"] for x in vi_flags if x["type"] == "issue"]
        self.vi_issue_slabels = [x["shortlabel"] for x in vi_flags if x["type"] == "issue"]
        self.js_files = get_resources('js')
        self.title = title
        self.vi_countdown_toggle = None

        #- List of fields to be recorded in output csv file, contains for each field:
        # [ field name (in VI file header), associated variable in viewer_cds.cds_metadata ]
        # Only fields actually present in cds_metadata are kept.
        metadata_keys = viewer_cds.cds_metadata.data.keys()
        self.output_file_fields = [
            [file_field[0], file_field[1]]
            for file_field in vi_file_fields
            if file_field[1] in metadata_keys
        ]

    def add_filename(self, username='******'):
        """Create the VI file-name input, defaulting to ``desi-vi_<title>_<username>.csv``."""
        #- VI file name
        default_vi_filename = "desi-vi_" + self.title + ("_" + username) + ".csv"
        self.vi_filename_input = TextInput(value=default_vi_filename, title="VI file name:")

    def add_vi_issues(self, viewer_cds, widgets):
        """Create the optional VI-issue checkboxes.

        On click, the short labels of the checked boxes are joined and stored in
        ``VI_issue_flag`` for the current fiber (or " " if none are checked).
        The metadata is then autosaved to browser localStorage.
        """
        #- Optional VI flags (issues)
        self.vi_issue_input = CheckboxGroup(labels=self.vi_issue_labels, active=[])
        vi_issue_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_issue_code += """
            var issues = []
            for (var i=0; i<vi_issue_labels.length; i++) {
                if (vi_issue_input.active.indexOf(i) >= 0) issues.push(vi_issue_slabels[i])
            }
            if (issues.length > 0) {
                cds_metadata.data['VI_issue_flag'][ifiberslider.value] = ( issues.join('') )
            } else {
                cds_metadata.data['VI_issue_flag'][ifiberslider.value] = " "
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_issue_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      ifiberslider=widgets.ifiberslider,
                      vi_issue_input=self.vi_issue_input,
                      vi_issue_labels=self.vi_issue_labels,
                      vi_issue_slabels=self.vi_issue_slabels,
                      title=self.title, output_file_fields=self.output_file_fields),
            code=vi_issue_code)
        self.vi_issue_input.js_on_click(self.vi_issue_callback)

    def add_vi_z(self, viewer_cds, widgets):
        """Create the VI redshift input and the "Copy z to VI" button.

        Edits to the input are stored in ``VI_z`` for the current fiber and
        autosaved. The button copies ``widgets.z_input.value`` into the input.
        """
        ## TODO: z_tovi behaviour if with_vi_widget=False ..?
        #- Optional VI information on redshift
        self.vi_z_input = TextInput(value='', title="VI redshift:")
        vi_z_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_z_code += """
            cds_metadata.data['VI_z'][ifiberslider.value]=vi_z_input.value
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_z_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      ifiberslider=widgets.ifiberslider,
                      vi_z_input=self.vi_z_input, title=self.title,
                      output_file_fields=self.output_file_fields),
            code=vi_z_code)
        self.vi_z_input.js_on_change('value', self.vi_z_callback)

        # Copy z value from redshift slider to VI
        self.z_tovi_button = Button(label='Copy z to VI')
        self.z_tovi_callback = CustomJS(
            args=dict(z_input=widgets.z_input, vi_z_input=self.vi_z_input),
            code="""
                vi_z_input.value = z_input.value
            """)
        self.z_tovi_button.js_on_event('button_click', self.z_tovi_callback)

    def add_vi_spectype(self, viewer_cds, widgets):
        """Create the VI spectype selector.

        The placeholder ' ' is stored as an empty string in ``VI_spectype``.
        """
        #- Optional VI information on spectral type
        self.vi_category_select = Select(value=' ', title="VI spectype:",
                                         options=([' '] + vi_spectypes))
        # The default value set to ' ' as setting value='' does not seem to work well with Select.
        vi_category_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_category_code += """
            if (vi_category_select.value == ' ') {
                cds_metadata.data['VI_spectype'][ifiberslider.value]=''
            } else {
                cds_metadata.data['VI_spectype'][ifiberslider.value]=vi_category_select.value
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_category_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      ifiberslider=widgets.ifiberslider,
                      vi_category_select=self.vi_category_select,
                      title=self.title, output_file_fields=self.output_file_fields),
            code=vi_category_code)
        self.vi_category_select.js_on_change('value', self.vi_category_callback)

    def add_vi_comment(self, viewer_cds, widgets):
        """Create the free-text VI comment input and the standard-comment selector.

        Before the comment is stored in ``VI_comment``:

        - commas are replaced by ';';
        - a few known non-ASCII characters are spelled out;
        - any other non-ASCII character becomes '?'.

        Picking a standard comment appends it to the free-text comment.
        """
        #- Optional VI comment
        self.vi_comment_input = TextInput(value='', title="VI comment (see guidelines):")
        vi_comment_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_comment_code += """
            var stored_comment = (vi_comment_input.value).replace(/./g, function(char){
                if ( char==',' ) {
                    return ';'
                } else if ( char.charCodeAt(0)<=127 ) {
                    return char
                } else {
                    var char_list = ['Å','α','β','γ','δ','λ']
                    var char_replace = ['Angstrom','alpha','beta','gamma','delta','lambda']
                    for (var i=0; i<char_list.length; i++) {
                        if ( char==char_list[i] ) return char_replace[i]
                    }
                    return '?'
                }
            })
            cds_metadata.data['VI_comment'][ifiberslider.value] = stored_comment
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_comment_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      ifiberslider=widgets.ifiberslider,
                      vi_comment_input=self.vi_comment_input,
                      title=self.title, output_file_fields=self.output_file_fields),
            code=vi_comment_code)
        self.vi_comment_input.js_on_change('value', self.vi_comment_callback)

        #- List of "standard" VI comment
        self.vi_std_comment_select = Select(value=" ", title="Standard comment:",
                                            options=([' '] + vi_std_comments))
        vi_std_comment_code = """
            if (vi_std_comment_select.value != ' ') {
                if (vi_comment_input.value != '') {
                    vi_comment_input.value = vi_comment_input.value + " " + vi_std_comment_select.value
                } else {
                    vi_comment_input.value = vi_std_comment_select.value
                }
            }
            """
        self.vi_std_comment_callback = CustomJS(
            args=dict(vi_std_comment_select=self.vi_std_comment_select,
                      vi_comment_input=self.vi_comment_input),
            code=vi_std_comment_code)
        self.vi_std_comment_select.js_on_change('value', self.vi_std_comment_callback)

    def add_vi_quality(self, viewer_cds, widgets):
        """Create the main VI quality radio group.

        The selected label is stored in ``VI_quality_flag``, or "-1" if nothing
        is selected.
        """
        #- Main VI quality widget
        self.vi_quality_input = RadioButtonGroup(labels=self.vi_quality_labels)
        vi_quality_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_quality_code += """
            if ( vi_quality_input.active >= 0 ) {
                cds_metadata.data['VI_quality_flag'][ifiberslider.value] = vi_quality_labels[vi_quality_input.active]
            } else {
                cds_metadata.data['VI_quality_flag'][ifiberslider.value] = "-1"
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_quality_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      vi_quality_input=self.vi_quality_input,
                      vi_quality_labels=self.vi_quality_labels,
                      ifiberslider=widgets.ifiberslider,
                      title=self.title, output_file_fields=self.output_file_fields),
            code=vi_quality_code)
        self.vi_quality_input.js_on_click(self.vi_quality_callback)

    def add_vi_scanner(self, viewer_cds):
        """Create the scanner-name input.

        A new name is written into ``VI_scanner`` for every row. It also
        replaces the last '_'-separated chunk of the VI file name.
        Requires ``add_filename`` to have been called (reads
        ``self.vi_filename_input``).
        """
        #- VI scanner name
        self.vi_name_input = TextInput(value=(viewer_cds.cds_metadata.data['VI_scanner'][0]).strip(),
                                       title="Your name (3-letter acronym):")
        vi_name_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_name_code += """
            for (var i=0; i<(cds_metadata.data['VI_scanner']).length; i++) {
                cds_metadata.data['VI_scanner'][i]=vi_name_input.value
            }
            var newname = vi_filename_input.value
            var name_chunks = newname.split("_")
            newname = ( name_chunks.slice(0,name_chunks.length-1).join("_") ) + ("_"+vi_name_input.value+".csv")
            vi_filename_input.value = newname
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            """
        self.vi_name_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      vi_name_input=self.vi_name_input,
                      vi_filename_input=self.vi_filename_input,
                      title=self.title, output_file_fields=self.output_file_fields),
            code=vi_name_code)
        self.vi_name_input.js_on_change('value', self.vi_name_callback)

    def add_guidelines(self):
        """Build an HTML ``Div`` listing the VI quality flags, issue flags and comment rules."""
        #- Guidelines for VI flags
        vi_guideline_txt = "<B> VI guidelines </B>"
        vi_guideline_txt += "<BR /> <B> Classification flags: </B>"
        for flag in vi_flags:
            if flag['type'] == 'quality':
                vi_guideline_txt += ("<BR />  [ " + flag['label'] + " ] " + flag['description'])
        vi_guideline_txt += "<BR /> <B> Optional indications: </B>"
        for flag in vi_flags:
            if flag['type'] == 'issue':
                vi_guideline_txt += ("<BR />  [ " + flag['label'] +
                                     " (" + flag['shortlabel'] + ") ] " + flag['description'])
        vi_guideline_txt += "<BR /> <B> Comments: </B> <BR /> 100 characters max, avoid commas (automatically replaced by semi-columns), ASCII only."
        self.vi_guideline_div = Div(text=vi_guideline_txt)

    def add_vi_storage(self, viewer_cds, widgets):
        """Create the download, recover-autosave and clear-autosave buttons.

        Requires ``add_filename``, ``add_vi_comment``, ``add_vi_scanner``,
        ``add_vi_quality`` and ``add_vi_issues`` to have been called (their
        widgets are passed to the JS callbacks).
        """
        #- Save VI info to CSV file
        self.save_vi_button = Button(label="Download VI", button_type="success")
        save_vi_code = self.js_files["FileSaver.js"] + self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        save_vi_code += """
            download_vi_file(output_file_fields, cds_metadata.data, vi_filename_input.value)
            """
        self.save_vi_callback = CustomJS(
            args=dict(cds_metadata=viewer_cds.cds_metadata,
                      output_file_fields=self.output_file_fields,
                      vi_filename_input=self.vi_filename_input),
            code=save_vi_code)
        self.save_vi_button.js_on_event('button_click', self.save_vi_callback)

        #- Recover auto-saved VI data in browser
        self.recover_vi_button = Button(label="Recover auto-saved VI", button_type="default")
        recover_vi_code = self.js_files["CSVtoArray.js"] + self.js_files["recover_autosave_vi.js"]
        # NOTE(review): `ifiber` is the slider's value at build time (a plain
        # number), not the slider model, so it will not follow later slider
        # moves. Confirm against recover_autosave_vi.js whether this is intended.
        self.recover_vi_callback = CustomJS(
            args=dict(title=self.title, output_file_fields=self.output_file_fields,
                      cds_metadata=viewer_cds.cds_metadata,
                      ifiber=widgets.ifiberslider.value,
                      vi_comment_input=self.vi_comment_input,
                      vi_name_input=self.vi_name_input,
                      vi_quality_input=self.vi_quality_input,
                      vi_issue_input=self.vi_issue_input,
                      vi_issue_slabels=self.vi_issue_slabels,
                      vi_quality_labels=self.vi_quality_labels),
            code=recover_vi_code)
        self.recover_vi_button.js_on_event('button_click', self.recover_vi_callback)

        #- Clear all auto-saved VI
        # Note: localStorage.clear() wipes every key for this origin, not only VI entries.
        self.clear_vi_button = Button(label="Clear all auto-saved VI", button_type="default")
        self.clear_vi_callback = CustomJS(args=dict(), code="""
            localStorage.clear()
            """)
        self.clear_vi_button.js_on_event('button_click', self.clear_vi_callback)

    def add_vi_table(self, viewer_cds):
        """Create a 10-row-high ``DataTable`` showing the VI columns of ``cds_metadata``."""
        #- Show VI in a table
        vi_table_columns = [
            TableColumn(field="VI_quality_flag", title="Flag", width=40),
            TableColumn(field="VI_issue_flag", title="Opt.", width=50),
            TableColumn(field="VI_z", title="VI z", width=50),
            TableColumn(field="VI_spectype", title="VI spectype", width=150),
            TableColumn(field="VI_comment", title="VI comment", width=200)
        ]
        self.vi_table = DataTable(source=viewer_cds.cds_metadata, columns=vi_table_columns, width=500)
        self.vi_table.height = 10 * self.vi_table.row_height

    def add_countdown(self, vi_countdown):
        """Create a toggle that starts a countdown of ``vi_countdown`` minutes.

        Raises
        ------
        AssertionError
            If ``vi_countdown`` is not positive.
        """
        #- VI countdown
        assert vi_countdown > 0
        self.vi_countdown_callback = CustomJS(args=dict(vi_countdown=vi_countdown),
                                              code=self._COUNTDOWN_JS)
        self.vi_countdown_toggle = Toggle(label='Start countdown (' + str(vi_countdown) + ' min)',
                                          active=False, button_type="success")
        self.vi_countdown_toggle.js_on_change('active', self.vi_countdown_callback)
def photometry_plot(obj_id, user, width=600, device="browser"): """Create object photometry scatter plot. Parameters ---------- obj_id : str ID of Obj to be plotted. Returns ------- dict Returns Bokeh JSON embedding for the desired plot. """ data = pd.read_sql( DBSession() .query( Photometry, Telescope.nickname.label("telescope"), Instrument.name.label("instrument"), ) .join(Instrument, Instrument.id == Photometry.instrument_id) .join(Telescope, Telescope.id == Instrument.telescope_id) .filter(Photometry.obj_id == obj_id) .filter( Photometry.groups.any(Group.id.in_([g.id for g in user.accessible_groups])) ) .statement, DBSession().bind, ) if data.empty: return None, None, None # get spectra to annotate on phot plots spectra = ( Spectrum.query_records_accessible_by(user) .filter(Spectrum.obj_id == obj_id) .all() ) data['color'] = [get_color(f) for f in data['filter']] # get marker for each unique instrument instruments = list(data.instrument.unique()) markers = [] for i, inst in enumerate(instruments): markers.append(phot_markers[i % len(phot_markers)]) filters = list(set(data['filter'])) colors = [get_color(f) for f in filters] color_mapper = CategoricalColorMapper(factors=filters, palette=colors) color_dict = {'field': 'filter', 'transform': color_mapper} labels = [] for i, datarow in data.iterrows(): label = f'{datarow["instrument"]}/{datarow["filter"]}' if datarow['origin'] is not None: label += f'/{datarow["origin"]}' labels.append(label) data['label'] = labels data['zp'] = PHOT_ZP data['magsys'] = 'ab' data['alpha'] = 1.0 data['lim_mag'] = ( -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) + data['zp'] ) # Passing a dictionary to a bokeh datasource causes the frontend to die, # deleting the dictionary column fixes that del data['original_user_data'] # keep track of things that are only upper limits data['hasflux'] = ~data['flux'].isna() # calculate the magnitudes - a photometry point is considered "significant" # or "detected" (and thus can be 
represented by a magnitude) if its snr # is above PHOT_DETECTION_THRESHOLD obsind = data['hasflux'] & ( data['flux'].fillna(0.0) / data['fluxerr'] >= PHOT_DETECTION_THRESHOLD ) data.loc[~obsind, 'mag'] = None data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP # calculate the magnitude errors using standard error propagation formulae # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae data.loc[~obsind, 'magerr'] = None coeff = 2.5 / np.log(10) magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux']) data.loc[obsind, 'magerr'] = magerrs data['obs'] = obsind data['stacked'] = False split = data.groupby('label', sort=False) finite = np.isfinite(data['flux']) fdata = data[finite] lower = np.min(fdata['flux']) * 0.95 upper = np.max(fdata['flux']) * 1.05 xmin = data['mjd'].min() - 2 xmax = data['mjd'].max() + 2 # Layout parameters based on device type active_drag = None if "mobile" in device or "tablet" in device else "box_zoom" tools = ( 'box_zoom,pan,reset' if "mobile" in device or "tablet" in device else "box_zoom,wheel_zoom,pan,reset,save" ) legend_loc = "below" if "mobile" in device or "tablet" in device else "right" legend_orientation = ( "vertical" if device in ["browser", "mobile_portrait"] else "horizontal" ) # Compute a plot component height based on rough number of legend rows added below the plot # Values are based on default sizing of bokeh components and an estimate of how many # legend items would fit on the average device screen. Note that the legend items per # row is computed more exactly later once labels are extracted from the data (with the # add_plot_legend() function). # # The height is manually computed like this instead of using built in aspect_ratio/sizing options # because with the new Interactive Legend approach (instead of the legacy CheckboxLegendGroup), the # Legend component is considered part of the plot and plays into the sizing computations. 
Since the # number of items in the legend can alter the needed heights of the plot, using built-in Bokeh options # for sizing does not allow for keeping the actual graph part of the plot at a consistent aspect ratio. # # For the frame width, by default we take the desired plot width minus 64 for the y-axis/label taking # up horizontal space frame_width = width - 64 if device == "mobile_portrait": legend_items_per_row = 1 legend_row_height = 24 aspect_ratio = 1 elif device == "mobile_landscape": legend_items_per_row = 4 legend_row_height = 50 aspect_ratio = 1.8 elif device == "tablet_portrait": legend_items_per_row = 5 legend_row_height = 50 aspect_ratio = 1.5 elif device == "tablet_landscape": legend_items_per_row = 7 legend_row_height = 50 aspect_ratio = 1.8 elif device == "browser": # Width minus some base width for the legend, which is only a column to the right # for browser mode frame_width = width - 200 height = ( 500 if device == "browser" else math.floor(width / aspect_ratio) + legend_row_height * int(len(split) / legend_items_per_row) + 30 # 30 is the height of the toolbar ) plot = figure( frame_width=frame_width, height=height, active_drag=active_drag, tools=tools, toolbar_location='above', toolbar_sticky=True, y_range=(lower, upper), min_border_right=16, x_axis_location='above', sizing_mode="stretch_width", ) plot.xaxis.axis_label = 'MJD' now = Time.now().mjd plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)} plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below') imhover = HoverTool(tooltips=tooltip_format) imhover.renderers = [] plot.add_tools(imhover) model_dict = {} legend_items = [] for i, (label, sdf) in enumerate(split): renderers = [] # for the flux plot, we only show things that have a flux value df = sdf[sdf['hasflux']] key = f'obs{i}' model_dict[key] = plot.scatter( x='mjd', y='flux', color='color', marker=factor_mark('instrument', markers, instruments), fill_color=color_dict, 
alpha='alpha', source=ColumnDataSource(df), ) renderers.append(model_dict[key]) imhover.renderers.append(model_dict[key]) key = f'bin{i}' model_dict[key] = plot.scatter( x='mjd', y='flux', color='color', marker=factor_mark('instrument', markers, instruments), fill_color=color_dict, source=ColumnDataSource( data=dict( mjd=[], flux=[], fluxerr=[], filter=[], color=[], lim_mag=[], mag=[], magerr=[], stacked=[], instrument=[], ) ), ) renderers.append(model_dict[key]) imhover.renderers.append(model_dict[key]) key = 'obserr' + str(i) y_err_x = [] y_err_y = [] for d, ro in df.iterrows(): px = ro['mjd'] py = ro['flux'] err = ro['fluxerr'] y_err_x.append((px, px)) y_err_y.append((py - err, py + err)) model_dict[key] = plot.multi_line( xs='xs', ys='ys', color='color', alpha='alpha', source=ColumnDataSource( data=dict( xs=y_err_x, ys=y_err_y, color=df['color'], alpha=[1.0] * len(df) ) ), ) renderers.append(model_dict[key]) key = f'binerr{i}' model_dict[key] = plot.multi_line( xs='xs', ys='ys', color='color', # legend_label=label, source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])), ) renderers.append(model_dict[key]) legend_items.append(LegendItem(label=label, renderers=renderers)) if device == "mobile_portrait": plot.xaxis.ticker.desired_num_ticks = 5 plot.yaxis.axis_label = 'Flux (μJy)' plot.toolbar.logo = None add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc) slider = Slider( start=0.0, end=15.0, value=0.0, step=1.0, title='Binsize (days)', max_width=350, margin=(4, 10, 0, 10), ) callback = CustomJS( args={'slider': slider, 'n_labels': len(split), **model_dict}, code=open( os.path.join(os.path.dirname(__file__), '../static/js/plotjs', 'stackf.js') ) .read() .replace('default_zp', str(PHOT_ZP)) .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)), ) slider.js_on_change('value', callback) # Mark the first and last detections detection_dates = data[data['hasflux']]['mjd'] if len(detection_dates) > 0: first = round(detection_dates.min(), 6) 
last = round(detection_dates.max(), 6) first_color = "#34b4eb" last_color = "#8992f5" midpoint = (upper + lower) / 2 line_top = 5 * upper - 4 * midpoint line_bottom = 5 * lower - 4 * midpoint y = np.linspace(line_bottom, line_top, num=5000) first_r = plot.line( x=np.full(5000, first), y=y, line_alpha=0.5, line_color=first_color, line_width=2, ) plot.add_tools( HoverTool( tooltips=[("First detection", f'{first}')], renderers=[first_r], ) ) last_r = plot.line( x=np.full(5000, last), y=y, line_alpha=0.5, line_color=last_color, line_width=2, ) plot.add_tools( HoverTool( tooltips=[("Last detection", f'{last}')], renderers=[last_r], ) ) # Mark when spectra were taken annotate_spec(plot, spectra, lower, upper) layout = column(slider, plot, width=width, height=height) p1 = Panel(child=layout, title='Flux') # now make the mag light curve ymax = ( np.nanmax( ( np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan, np.nanmax(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan, ) ) + 0.1 ) ymin = ( np.nanmin( ( np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan, np.nanmin(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan, ) ) - 0.1 ) plot = figure( frame_width=frame_width, height=height, active_drag=active_drag, tools=tools, y_range=(ymax, ymin), x_range=(xmin, xmax), toolbar_location='above', toolbar_sticky=True, x_axis_location='above', sizing_mode="stretch_width", ) plot.xaxis.axis_label = 'MJD' now = Time.now().mjd plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)} plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below') obj = DBSession().query(Obj).get(obj_id) if obj.dm is not None: plot.extra_y_ranges = { "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm) } plot.add_layout( LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right' ) # Mark the first and last detections again detection_dates = data[obsind]['mjd'] if len(detection_dates) > 0: first = 
round(detection_dates.min(), 6) last = round(detection_dates.max(), 6) midpoint = (ymax + ymin) / 2 line_top = 5 * ymax - 4 * midpoint line_bottom = 5 * ymin - 4 * midpoint y = np.linspace(line_bottom, line_top, num=5000) first_r = plot.line( x=np.full(5000, first), y=y, line_alpha=0.5, line_color=first_color, line_width=2, ) plot.add_tools( HoverTool( tooltips=[("First detection", f'{first}')], renderers=[first_r], ) ) last_r = plot.line( x=np.full(5000, last), y=y, line_alpha=0.5, line_color=last_color, line_width=2, ) plot.add_tools( HoverTool( tooltips=[("Last detection", f'{last}')], renderers=[last_r], point_policy='follow_mouse', ) ) # Mark when spectra were taken annotate_spec(plot, spectra, ymax, ymin) imhover = HoverTool(tooltips=tooltip_format) imhover.renderers = [] plot.add_tools(imhover) model_dict = {} # Legend items are individually stored instead of being applied # directly when plotting so that they can be separated into multiple # Legend() components if needed (to simulate horizontal row wrapping). # This is necessary because Bokeh does not support row wrapping with # horizontally-oriented legends out-of-the-box. 
legend_items = [] for i, (label, df) in enumerate(split): renderers = [] key = f'obs{i}' model_dict[key] = plot.scatter( x='mjd', y='mag', color='color', marker=factor_mark('instrument', markers, instruments), fill_color=color_dict, alpha='alpha', source=ColumnDataSource(df[df['obs']]), ) renderers.append(model_dict[key]) imhover.renderers.append(model_dict[key]) unobs_source = df[~df['obs']].copy() unobs_source.loc[:, 'alpha'] = 0.8 key = f'unobs{i}' model_dict[key] = plot.scatter( x='mjd', y='lim_mag', color=color_dict, marker='inverted_triangle', fill_color='white', line_color='color', alpha='alpha', source=ColumnDataSource(unobs_source), ) renderers.append(model_dict[key]) imhover.renderers.append(model_dict[key]) key = f'bin{i}' model_dict[key] = plot.scatter( x='mjd', y='mag', color=color_dict, marker=factor_mark('instrument', markers, instruments), fill_color='color', source=ColumnDataSource( data=dict( mjd=[], flux=[], fluxerr=[], filter=[], color=[], lim_mag=[], mag=[], magerr=[], instrument=[], stacked=[], ) ), ) renderers.append(model_dict[key]) imhover.renderers.append(model_dict[key]) key = 'obserr' + str(i) y_err_x = [] y_err_y = [] for d, ro in df[df['obs']].iterrows(): px = ro['mjd'] py = ro['mag'] err = ro['magerr'] y_err_x.append((px, px)) y_err_y.append((py - err, py + err)) model_dict[key] = plot.multi_line( xs='xs', ys='ys', color='color', alpha='alpha', source=ColumnDataSource( data=dict( xs=y_err_x, ys=y_err_y, color=df[df['obs']]['color'], alpha=[1.0] * len(df[df['obs']]), ) ), ) renderers.append(model_dict[key]) key = f'binerr{i}' model_dict[key] = plot.multi_line( xs='xs', ys='ys', color='color', source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])), ) renderers.append(model_dict[key]) key = f'unobsbin{i}' model_dict[key] = plot.scatter( x='mjd', y='lim_mag', color='color', marker='inverted_triangle', fill_color='white', line_color=color_dict, alpha=0.8, source=ColumnDataSource( data=dict( mjd=[], flux=[], fluxerr=[], filter=[], 
color=[], lim_mag=[], mag=[], magerr=[], instrument=[], stacked=[], ) ), ) imhover.renderers.append(model_dict[key]) renderers.append(model_dict[key]) key = f'all{i}' model_dict[key] = ColumnDataSource(df) key = f'bold{i}' model_dict[key] = ColumnDataSource( df[ [ 'mjd', 'flux', 'fluxerr', 'mag', 'magerr', 'filter', 'zp', 'magsys', 'lim_mag', 'stacked', ] ] ) legend_items.append(LegendItem(label=label, renderers=renderers)) add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc) plot.yaxis.axis_label = 'AB mag' plot.toolbar.logo = None slider = Slider( start=0.0, end=15.0, value=0.0, step=1.0, title='Binsize (days)', max_width=350, margin=(4, 10, 0, 10), ) button = Button(label="Export Bold Light Curve to CSV") button.js_on_click( CustomJS( args={'slider': slider, 'n_labels': len(split), **model_dict}, code=open( os.path.join( os.path.dirname(__file__), '../static/js/plotjs', "download.js" ) ) .read() .replace('objname', obj_id) .replace('default_zp', str(PHOT_ZP)), ) ) # Don't need to expose CSV download on mobile top_layout = ( slider if "mobile" in device or "tablet" in device else row(slider, button) ) callback = CustomJS( args={'slider': slider, 'n_labels': len(split), **model_dict}, code=open( os.path.join(os.path.dirname(__file__), '../static/js/plotjs', 'stackm.js') ) .read() .replace('default_zp', str(PHOT_ZP)) .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)), ) slider.js_on_change('value', callback) layout = column(top_layout, plot, width=width, height=height) p2 = Panel(child=layout, title='Mag') # now make period plot # get periods from annotations annotation_list = obj.get_annotations_readable_by(user) period_labels = [] period_list = [] for an in annotation_list: if 'period' in an.data: period_list.append(an.data['period']) period_labels.append(an.origin + ": %.9f" % an.data['period']) if len(period_list) > 0: period = period_list[0] else: period = None # don't generate if no period annotated if period is not None: # bokeh 
figure for period plotting period_plot = figure( frame_width=frame_width, height=height, active_drag=active_drag, tools=tools, y_range=(ymax, ymin), x_range=(-0.01, 2.01), # initially one phase toolbar_location='above', toolbar_sticky=False, x_axis_location='below', sizing_mode="stretch_width", ) # axis labels period_plot.xaxis.axis_label = 'phase' period_plot.yaxis.axis_label = 'mag' period_plot.toolbar.logo = None # do we have a distance modulus (dm)? obj = DBSession().query(Obj).get(obj_id) if obj.dm is not None: period_plot.extra_y_ranges = { "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm) } period_plot.add_layout( LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right' ) # initiate hover tool period_imhover = HoverTool(tooltips=tooltip_format) period_imhover.renderers = [] period_plot.add_tools(period_imhover) # initiate period radio buttons period_selection = RadioGroup(labels=period_labels, active=0) phase_selection = RadioGroup(labels=["One phase", "Two phases"], active=1) # store all the plot data period_model_dict = {} # iterate over each filter legend_items = [] for i, (label, df) in enumerate(split): renderers = [] # fold x-axis on period in days df['mjd_folda'] = (df['mjd'] % period) / period df['mjd_foldb'] = df['mjd_folda'] + 1.0 # phase plotting for ph in ['a', 'b']: key = 'fold' + ph + f'{i}' period_model_dict[key] = period_plot.scatter( x='mjd_fold' + ph, y='mag', color='color', marker=factor_mark('instrument', markers, instruments), fill_color=color_dict, alpha='alpha', # visible=('a' in ph), source=ColumnDataSource(df[df['obs']]), # only visible data ) # add to hover tool period_imhover.renderers.append(period_model_dict[key]) renderers.append(period_model_dict[key]) # errorbars for phases key = 'fold' + ph + f'err{i}' y_err_x = [] y_err_y = [] # get each visible error value for d, ro in df[df['obs']].iterrows(): px = ro['mjd_fold' + ph] py = ro['mag'] err = ro['magerr'] # set up error tuples y_err_x.append((px, 
px)) y_err_y.append((py - err, py + err)) # plot phase errors period_model_dict[key] = period_plot.multi_line( xs='xs', ys='ys', color='color', alpha='alpha', # visible=('a' in ph), source=ColumnDataSource( data=dict( xs=y_err_x, ys=y_err_y, color=df[df['obs']]['color'], alpha=[1.0] * len(df[df['obs']]), ) ), ) renderers.append(period_model_dict[key]) legend_items.append(LegendItem(label=label, renderers=renderers)) add_plot_legend( period_plot, legend_items, width, legend_orientation, legend_loc ) # set up period adjustment text box period_title = Div(text="Period (days): ") period_textinput = TextInput(value=str(period if period is not None else 0.0)) period_textinput.js_on_change( 'value', CustomJS( args={ 'textinput': period_textinput, 'numphases': phase_selection, 'n_labels': len(split), 'p': period_plot, **period_model_dict, }, code=open( os.path.join( os.path.dirname(__file__), '../static/js/plotjs', 'foldphase.js' ) ).read(), ), ) # a way to modify the period period_double_button = Button(label="*2", width=30) period_double_button.js_on_click( CustomJS( args={'textinput': period_textinput}, code=""" const period = parseFloat(textinput.value); textinput.value = parseFloat(2.*period).toFixed(9); """, ) ) period_halve_button = Button(label="/2", width=30) period_halve_button.js_on_click( CustomJS( args={'textinput': period_textinput}, code=""" const period = parseFloat(textinput.value); textinput.value = parseFloat(period/2.).toFixed(9); """, ) ) # a way to select the period period_selection.js_on_click( CustomJS( args={'textinput': period_textinput, 'periods': period_list}, code=""" textinput.value = parseFloat(periods[this.active]).toFixed(9); """, ) ) phase_selection.js_on_click( CustomJS( args={ 'textinput': period_textinput, 'numphases': phase_selection, 'n_labels': len(split), 'p': period_plot, **period_model_dict, }, code=open( os.path.join( os.path.dirname(__file__), '../static/js/plotjs', 'foldphase.js' ) ).read(), ) ) # layout if device == 
"mobile_portrait": period_controls = column( row( period_title, period_textinput, period_double_button, period_halve_button, width=width, sizing_mode="scale_width", ), phase_selection, period_selection, width=width, ) # Add extra height to plot based on period control components added # 18 is the height of each period selection radio option (per default font size) # and the 130 encompasses the other components which are consistent no matter # the data size. height += 130 + 18 * len(period_labels) else: period_controls = column( row( period_title, period_textinput, period_double_button, period_halve_button, phase_selection, width=width, sizing_mode="scale_width", ), period_selection, margin=10, ) # Add extra height to plot based on period control components added # Numbers are derived in similar manner to the "mobile_portrait" case above height += 90 + 18 * len(period_labels) period_layout = column(period_plot, period_controls, width=width, height=height) # Period panel p3 = Panel(child=period_layout, title='Period') # tabs for mag, flux, period tabs = Tabs(tabs=[p2, p1, p3], width=width, height=height, sizing_mode='fixed') else: # tabs for mag, flux tabs = Tabs(tabs=[p2, p1], width=width, height=height + 90, sizing_mode='fixed') return bokeh_embed.json_item(tabs)
def spectroscopy_plot(obj_id, user, spec_id=None, width=600, device="browser"):
    """Create a Bokeh plot of the spectra of an object.

    Parameters
    ----------
    obj_id : str
        ID of the Obj whose spectra are plotted.
    user : object
        User whose ``accessible_groups`` restrict which spectra are shown.
    spec_id : int or str, optional
        If given, only the spectrum with this ID (compared via ``int(spec_id)``)
        is plotted.
    width : int, optional
        Target width of the layout in pixels.
    device : str, optional
        One of "browser", "mobile_portrait", "mobile_landscape",
        "tablet_portrait" or "tablet_landscape"; controls sizing, legend
        placement and tools. NOTE(review): any other value leaves
        ``aspect_ratio`` / ``legend_items_per_row`` unset and raises NameError.

    Returns
    -------
    dict or tuple
        Bokeh JSON embedding (``bokeh_embed.json_item``) of the layout, or
        ``(None, None, None)`` when no accessible spectra match.
    """
    obj = Obj.query.get(obj_id)
    spectra = (
        DBSession()
        .query(Spectrum)
        .join(Obj)
        .join(GroupSpectrum)
        .filter(
            Spectrum.obj_id == obj_id,
            GroupSpectrum.group_id.in_([g.id for g in user.accessible_groups]),
        )
    ).all()

    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if len(spectra) == 0:
        return None, None, None

    rainbow = cm.get_cmap('rainbow', len(spectra))
    palette = list(map(rgb2hex, rainbow(range(len(spectra)))))
    color_map = dict(zip([s.id for s in spectra], palette))

    data = []
    dfs = []
    for s in spectra:
        # normalize spectra to a median flux of 1 for easy comparison
        normfac = np.nanmedian(np.abs(s.fluxes))
        normfac = normfac if normfac != 0.0 else 1e-20
        normalized_fluxes = s.fluxes / normfac
        df = pd.DataFrame(
            {
                'wavelength': s.wavelengths,
                'flux': normalized_fluxes,
                'id': s.id,
                'telescope': s.instrument.telescope.name,
                'instrument': s.instrument.name,
                'date_observed': s.observed_at.isoformat(
                    sep=' ', timespec='seconds'
                ),
                'pi': (
                    s.assignment.run.pi
                    if s.assignment is not None
                    else (
                        s.followup_request.allocation.pi
                        if s.followup_request is not None
                        else ""
                    )
                ),
            }
        )
        data.append(df)

        # Smooth the spectrum by using a rolling average. The smoothed data only
        # sets the initial y-range, so it must use the same normalization as
        # the plotted data.
        smoothed = (
            pd.DataFrame({'wavelength': s.wavelengths, 'flux': normalized_fluxes})
            .rolling(2)
            .mean(numeric_only=True)
            .dropna()
        )
        dfs.append(smoothed)

    data = pd.concat(data)
    data.sort_values(by=['date_observed', 'wavelength'], inplace=True)
    smoothed_data = pd.concat(dfs)

    # sort=False keeps the groups in order of first appearance, i.e. by date
    split = data.groupby('id', sort=False)

    hover = HoverTool(
        tooltips=[
            ('wavelength', '@wavelength{0,0.000}'),
            ('flux', '@flux'),
            ('telescope', '@telescope'),
            ('instrument', '@instrument'),
            ('UTC date observed', '@date_observed'),
            ('PI', '@pi'),
        ]
    )
    smoothed_max = np.max(smoothed_data['flux'])
    smoothed_min = np.min(smoothed_data['flux'])
    ymax = smoothed_max * 1.05
    ymin = smoothed_min - 0.05 * (smoothed_max - smoothed_min)
    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    if obj.redshift is not None and obj.redshift > 0:
        xmin_rest = xmin / (1.0 + obj.redshift)
        xmax_rest = xmax / (1.0 + obj.redshift)

    active_drag = None if "mobile" in device or "tablet" in device else "box_zoom"
    tools = (
        "box_zoom, pan, reset"
        if "mobile" in device or "tablet" in device
        else "box_zoom,wheel_zoom,pan,reset"
    )

    # These values are equivalent from the photometry plot values
    frame_width = width - 64
    if device == "mobile_portrait":
        legend_items_per_row = 1
        legend_row_height = 24
        aspect_ratio = 1
    elif device == "mobile_landscape":
        legend_items_per_row = 4
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "tablet_portrait":
        legend_items_per_row = 5
        legend_row_height = 50
        aspect_ratio = 1.5
    elif device == "tablet_landscape":
        legend_items_per_row = 7
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "browser":
        frame_width = width - 200
    plot_height = (
        400
        if device == "browser"
        else math.floor(width / aspect_ratio)
        + legend_row_height * int(len(split) / legend_items_per_row)
        + 30  # 30 is the height of the toolbar
    )

    plot = figure(
        frame_width=frame_width,
        height=plot_height,
        y_range=(ymin, ymax),
        x_range=(xmin, xmax),
        tools=tools,
        toolbar_location="above",
        active_drag=active_drag,
    )
    plot.add_tools(hover)
    model_dict = {}
    legend_items = []
    # Reuse the already-loaded (access-filtered) spectra instead of issuing
    # one extra database query per spectrum.
    spectra_by_id = {s.id: s for s in spectra}
    for i, (key, df) in enumerate(split):
        renderers = []
        s = spectra_by_id[key]
        label = f'{s.instrument.name} ({s.observed_at.date().strftime("%m/%d/%y")})'
        model_dict['s' + str(i)] = plot.step(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
        )
        renderers.append(model_dict['s' + str(i)])
        legend_items.append(LegendItem(label=label, renderers=renderers))

        # Invisible line over the same data (line_alpha=0.0)
        model_dict['l' + str(i)] = plot.line(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
            line_alpha=0.0,
        )

    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None
    if obj.redshift is not None and obj.redshift > 0:
        plot.extra_x_ranges = {"rest_wave": Range1d(start=xmin_rest, end=xmax_rest)}
        plot.add_layout(
            LinearAxis(x_range_name="rest_wave", axis_label="Rest Wavelength (Å)"),
            'above',
        )

    # TODO how to choose a good default?
    plot.y_range = Range1d(0, 1.03 * data.flux.max())

    legend_loc = "below" if "mobile" in device or "tablet" in device else "right"
    legend_orientation = (
        "vertical" if device in ["browser", "mobile_portrait"] else "horizontal"
    )

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)
    # 20 is for padding
    slider_width = width if "mobile" in device else int(width / 2) - 20
    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=obj.redshift if obj.redshift is not None else 0.0,
        start=0.0,
        end=3.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(
        value=str(obj.redshift if obj.redshift is not None else 0.0)
    )
    z_slider.js_on_change(
        'value',
        CustomJS(
            args={'slider': z_slider, 'textinput': z_textinput},
            code="""
            textinput.value = parseFloat(slider.value).toFixed(3);
            textinput.change.emit();
        """,
        ),
    )

    z = column(
        z_title,
        z_slider,
        z_textinput,
        width=slider_width,
        margin=(4, 10, 0, 10),
    )

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.js_on_change(
        'value',
        CustomJS(
            args={'slider': v_exp_slider, 'textinput': v_exp_textinput},
            code="""
            textinput.value = parseFloat(slider.value).toFixed(0);
            textinput.change.emit();
        """,
        ),
    )
    v_exp = column(
        v_title,
        v_exp_slider,
        v_exp_textinput,
        width=slider_width,
        margin=(0, 10, 0, 10),
    )

    # One hidden segment renderer per spectral-line group; shifted by the
    # object's redshift initially, and toggled/shifted by the JS callbacks below.
    obj_redshift = 0 if obj.redshift is None else obj.redshift
    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + obj_redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            y0=0,
            y1=1e4,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral line legend into columns
    if device == "mobile_portrait":
        columns = 3
    elif device == "mobile_landscape":
        columns = 5
    else:
        columns = 7

    # Chunk SPEC_LINES into rows of `columns` items (padded with None), then
    # transpose: column k holds items k, k + columns, k + 2 * columns, ...
    # which is the index stride the JS callback below walks.
    element_dicts = zip(*itertools.zip_longest(*[iter(SPEC_LINES.items())] * columns))

    elements_groups = []  # The Bokeh checkbox groups
    callbacks = []  # The checkbox callbacks for each element
    for column_idx, element_dict in enumerate(element_dicts):
        element_dict = [e for e in element_dict if e is not None]
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(
            labels=labels, active=[], colors=colors, width=width // (columns + 1)
        )
        elements_groups.append(elements)

        callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
          let c = 299792.458; // speed of light in km / s
          const i_max = {column_idx} + {columns} * elements.labels.length;
          let local_i = 0;
          for (let i = {column_idx}; i < i_max; i = i + {columns}) {{
            let el = eval("el" + i);
            el.visible = (elements.active.includes(local_i))
            el.data_source.data.x = el.data_source.data.wavelength.map(
              x_i => (x_i * (1 + parseFloat(z.value)) /
                            (1 + parseFloat(v_exp.value) / c))
            );
            el.data_source.change.emit();
            local_i++;
          }}
        """,
        )
        elements.js_on_click(callback)
        callbacks.append(callback)

    # Slider.value is numeric: convert the rounded text back to a Number
    # instead of assigning the string returned by toFixed().
    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': z_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = Number(parseFloat(z.value).toFixed(3));
        """,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': v_exp_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = Number(parseFloat(v_exp.value).toFixed(3));
        """,
        ),
    )

    # Update the element spectral lines as well
    for callback in callbacks:
        z_textinput.js_on_change('value', callback)
        v_exp_textinput.js_on_change('value', callback)

    # Add some height for the checkboxes and sliders
    if device == "mobile_portrait":
        height = plot_height + 400
    elif device == "mobile_landscape":
        height = plot_height + 350
    else:
        height = plot_height + 200

    row2 = row(elements_groups)
    row3 = column(z, v_exp) if "mobile" in device else row(z, v_exp)
    layout = column(
        plot,
        row2,
        row3,
        sizing_mode='stretch_width',
        width=width,
        height=height,
    )
    return bokeh_embed.json_item(layout)
def photometry_plot(obj_id, user, width=600, height=300, device="browser"):
    """Create object photometry scatter plot.

    Builds a tabbed Bokeh layout with a magnitude light curve ("Mag"), a flux
    light curve ("Flux") and, when a period annotation exists, a phase-folded
    light curve ("Period").

    Parameters
    ----------
    obj_id : str
        ID of Obj to be plotted.
    user : object
        User whose ``accessible_groups`` restrict the photometry shown; also
        used to select readable spectra and annotations.
    width : int, optional
        Width of the tabs / layouts in pixels.
    height : int, optional
        Height of the tabs in pixels.
    device : str, optional
        Device hint; substrings "mobile"/"tablet" and the exact values
        "mobile_portrait"/"mobile_landscape" adjust tools and layout.

    Returns
    -------
    dict or tuple
        Returns Bokeh JSON embedding for the desired plot, or
        ``(None, None, None)`` when no accessible photometry exists.
    """

    def _read_plotjs(filename):
        # Read a JS callback file from ../static/js/plotjs, closing the handle.
        path = os.path.join(
            os.path.dirname(__file__), '../static/js/plotjs', filename
        )
        with open(path, encoding='utf-8') as f:
            return f.read()

    data = pd.read_sql(
        DBSession()
        .query(
            Photometry,
            Telescope.nickname.label("telescope"),
            Instrument.name.label("instrument"),
        )
        .join(Instrument, Instrument.id == Photometry.instrument_id)
        .join(Telescope, Telescope.id == Instrument.telescope_id)
        .filter(Photometry.obj_id == obj_id)
        .filter(
            Photometry.groups.any(
                Group.id.in_([g.id for g in user.accessible_groups])
            )
        )
        .statement,
        DBSession().bind,
    )

    if data.empty:
        return None, None, None

    # get spectra to annotate on phot plots
    spectra = (
        Spectrum.query_records_accessible_by(user)
        .filter(Spectrum.obj_id == obj_id)
        .all()
    )

    data['color'] = [get_color(f) for f in data['filter']]

    labels = []
    for _, datarow in data.iterrows():
        label = f'{datarow["instrument"]}/{datarow["filter"]}'
        if datarow['origin'] is not None:
            label += f'/{datarow["origin"]}'
        labels.append(label)

    data['label'] = labels
    data['zp'] = PHOT_ZP
    data['magsys'] = 'ab'
    data['alpha'] = 1.0
    data['lim_mag'] = (
        -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) + data['zp']
    )

    # Passing a dictionary to a bokeh datasource causes the frontend to die,
    # deleting the dictionary column fixes that
    del data['original_user_data']

    # keep track of things that are only upper limits
    data['hasflux'] = ~data['flux'].isna()

    # calculate the magnitudes - a photometry point is considered "significant"
    # or "detected" (and thus can be represented by a magnitude) if its snr
    # is above PHOT_DETECTION_THRESHOLD
    obsind = data['hasflux'] & (
        data['flux'].fillna(0.0) / data['fluxerr'] >= PHOT_DETECTION_THRESHOLD
    )
    data.loc[~obsind, 'mag'] = None
    data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP

    # calculate the magnitude errors using standard error propagation formulae
    # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae
    data.loc[~obsind, 'magerr'] = None
    coeff = 2.5 / np.log(10)
    magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux'])
    data.loc[obsind, 'magerr'] = magerrs
    data['obs'] = obsind
    data['stacked'] = False

    split = data.groupby('label', sort=False)

    finite = np.isfinite(data['flux'])
    fdata = data[finite]
    lower = np.min(fdata['flux']) * 0.95
    upper = np.max(fdata['flux']) * 1.05

    active_drag = None if "mobile" in device or "tablet" in device else "box_zoom"
    tools = (
        'box_zoom,pan,reset'
        if "mobile" in device or "tablet" in device
        else "box_zoom,wheel_zoom,pan,reset,save"
    )

    # Colors of the first/last detection markers; defined up front because
    # both the flux and the mag plots use them.
    first_color = "#34b4eb"
    last_color = "#8992f5"

    plot = figure(
        aspect_ratio=2.0 if device == "mobile_landscape" else 1.5,
        sizing_mode='scale_both',
        active_drag=active_drag,
        tools=tools,
        toolbar_location='above',
        toolbar_sticky=True,
        y_range=(lower, upper),
        min_border_right=16,
    )
    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    # Renderer/source keys (obs{i}, bin{i}, obserr{i}, binerr{i}) are looked up
    # by name in the JS callback files, so they must not be renamed.
    model_dict = {}

    for i, (label, sdf) in enumerate(split):

        # for the flux plot, we only show things that have a flux value
        df = sdf[sdf['hasflux']]

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(df),
        )

        imhover.renderers.append(model_dict[key])

        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    stacked=[],
                    instrument=[],
                )
            ),
        )

        imhover.renderers.append(model_dict[key])

        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for d, ro in df.iterrows():
            px = ro['mjd']
            py = ro['flux']
            err = ro['fluxerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x, ys=y_err_y, color=df['color'], alpha=[1.0] * len(df)
                )
            ),
        )

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

    plot.xaxis.axis_label = 'MJD'
    if device == "mobile_portrait":
        plot.xaxis.ticker.desired_num_ticks = 5
    plot.yaxis.axis_label = 'Flux (μJy)'
    plot.toolbar.logo = None

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline=True if "tablet" in device else False,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={'toggle': toggle, **model_dict},
            code=_read_plotjs('togglef.js'),
        )
    )

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    callback = CustomJS(
        args={'slider': slider, 'toggle': toggle, **model_dict},
        code=_read_plotjs('stackf.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)

    # Mark the first and last detections
    detection_dates = data[data['hasflux']]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        # Vertical lines extend well beyond the visible y-range
        midpoint = (upper + lower) / 2
        line_top = 5 * upper - 4 * midpoint
        line_bottom = 5 * lower - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, lower, upper)

    plot_layout = (
        column(plot, toggle)
        if "mobile" in device or "tablet" in device
        else row(plot, toggle)
    )
    layout = column(slider, plot_layout, sizing_mode='scale_width', width=width)

    p1 = Panel(child=layout, title='Flux')

    # now make the mag light curve
    ymax = (
        np.nanmax(
            (
                np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmax(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        + 0.1
    )
    ymin = (
        np.nanmin(
            (
                np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmin(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        - 0.1
    )

    xmin = data['mjd'].min() - 2
    xmax = data['mjd'].max() + 2

    # y_range is (ymax, ymin) so that brighter (smaller) magnitudes are on top
    plot = figure(
        aspect_ratio=2.0 if device == "mobile_landscape" else 1.5,
        sizing_mode='scale_both',
        width=width,
        active_drag=active_drag,
        tools=tools,
        y_range=(ymax, ymin),
        x_range=(xmin, xmax),
        toolbar_location='above',
        toolbar_sticky=True,
        x_axis_location='above',
    )

    # Mark the first and last detections again
    detection_dates = data[obsind]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        midpoint = (ymax + ymin) / 2
        line_top = 5 * ymax - 4 * midpoint
        line_bottom = 5 * ymin - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
                point_policy='follow_mouse',
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, ymax, ymin)

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    # Keys (obs/unobs/bin/obserr/binerr/unobsbin/all/bold + index) are looked
    # up by name in the JS callback files, so they must not be renamed.
    model_dict = {}

    for i, (label, df) in enumerate(split):

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(df[df['obs']]),
        )

        imhover.renderers.append(model_dict[key])

        # Non-detections are drawn as upper limits at their limiting magnitude
        unobs_source = df[~df['obs']].copy()
        unobs_source.loc[:, 'alpha'] = 0.8

        key = f'unobs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha='alpha',
            source=ColumnDataSource(unobs_source),
        )

        imhover.renderers.append(model_dict[key])

        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    instrument=[],
                    stacked=[],
                )
            ),
        )

        imhover.renderers.append(model_dict[key])

        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for d, ro in df[df['obs']].iterrows():
            px = ro['mjd']
            py = ro['mag']
            err = ro['magerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x,
                    ys=y_err_y,
                    color=df[df['obs']]['color'],
                    alpha=[1.0] * len(df[df['obs']]),
                )
            ),
        )

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

        key = f'unobsbin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha=0.8,
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    instrument=[],
                    stacked=[],
                )
            ),
        )
        imhover.renderers.append(model_dict[key])

        key = f'all{i}'
        model_dict[key] = ColumnDataSource(df)

        key = f'bold{i}'
        model_dict[key] = ColumnDataSource(
            df[
                [
                    'mjd',
                    'flux',
                    'fluxerr',
                    'mag',
                    'magerr',
                    'filter',
                    'zp',
                    'magsys',
                    'lim_mag',
                    'stacked',
                ]
            ]
        )

    plot.xaxis.axis_label = 'MJD'
    plot.yaxis.axis_label = 'AB mag'
    plot.toolbar.logo = None

    obj = DBSession().query(Obj).get(obj_id)
    if obj.dm is not None:
        plot.extra_y_ranges = {
            "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
        }
        plot.add_layout(
            LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
        )

    now = Time.now().mjd
    plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)}
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below')

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline=True if "tablet" in device else False,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={'toggle': toggle, **model_dict},
            code=_read_plotjs('togglem.js'),
        )
    )

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    button = Button(label="Export Bold Light Curve to CSV")
    button.js_on_click(
        CustomJS(
            args={'slider': slider, 'toggle': toggle, **model_dict},
            code=_read_plotjs("download.js")
            .replace('objname', obj_id)
            .replace('default_zp', str(PHOT_ZP)),
        )
    )

    # Don't need to expose CSV download on mobile
    top_layout = (
        slider if "mobile" in device or "tablet" in device else row(slider, button)
    )

    callback = CustomJS(
        args={'slider': slider, 'toggle': toggle, **model_dict},
        code=_read_plotjs('stackm.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)
    plot_layout = (
        column(plot, toggle)
        if "mobile" in device or "tablet" in device
        else row(plot, toggle)
    )
    layout = column(top_layout, plot_layout, sizing_mode='scale_width', width=width)

    p2 = Panel(child=layout, title='Mag')

    # now make period plot

    # get periods from annotations
    annotation_list = obj.get_annotations_readable_by(user)
    period_labels = []
    period_list = []
    for an in annotation_list:
        if 'period' in an.data:
            period_list.append(an.data['period'])
            period_labels.append(an.origin + ": %.9f" % an.data['period'])

    # the first annotated period is the initial fold period
    if len(period_list) > 0:
        period = period_list[0]
    else:
        period = None

    # don't generate if no period annotated
    if period is not None:

        # bokeh figure for period plotting
        period_plot = figure(
            aspect_ratio=1.5,
            sizing_mode='scale_both',
            active_drag='box_zoom',
            tools='box_zoom,wheel_zoom,pan,reset,save',
            y_range=(ymax, ymin),
            x_range=(-0.1, 1.1),  # initially one phase
            toolbar_location='above',
            toolbar_sticky=False,
            x_axis_location='below',
        )

        # axis labels
        period_plot.xaxis.axis_label = 'phase'
        period_plot.yaxis.axis_label = 'mag'
        period_plot.toolbar.logo = None

        # do we have a distance modulus (dm)? (`obj` was loaded above)
        if obj.dm is not None:
            period_plot.extra_y_ranges = {
                "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
            }
            period_plot.add_layout(
                LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
            )

        # initiate hover tool
        period_imhover = HoverTool(tooltips=tooltip_format)
        period_imhover.renderers = []
        period_plot.add_tools(period_imhover)

        # initiate period radio buttons
        period_selection = RadioGroup(labels=period_labels, active=0)

        phase_selection = RadioGroup(labels=["One phase", "Two phases"], active=0)

        # store all the plot data
        period_model_dict = {}

        # iterate over each filter
        for i, (label, df) in enumerate(split):

            # fold x-axis on period in days; phase "b" is phase "a" shifted by
            # one so that two phases can be shown side by side
            df['mjd_folda'] = (df['mjd'] % period) / period
            df['mjd_foldb'] = df['mjd_folda'] + 1.0

            # phase plotting
            for ph in ['a', 'b']:
                key = 'fold' + ph + f'{i}'
                period_model_dict[key] = period_plot.scatter(
                    x='mjd_fold' + ph,
                    y='mag',
                    color='color',
                    marker='circle',
                    fill_color='color',
                    alpha='alpha',
                    visible=('a' in ph),
                    source=ColumnDataSource(df[df['obs']]),  # only visible data
                )
                # add to hover tool
                period_imhover.renderers.append(period_model_dict[key])

                # errorbars for phases
                key = 'fold' + ph + f'err{i}'
                y_err_x = []
                y_err_y = []

                # get each visible error value
                for d, ro in df[df['obs']].iterrows():
                    px = ro['mjd_fold' + ph]
                    py = ro['mag']
                    err = ro['magerr']
                    # set up error tuples
                    y_err_x.append((px, px))
                    y_err_y.append((py - err, py + err))
                # plot phase errors
                period_model_dict[key] = period_plot.multi_line(
                    xs='xs',
                    ys='ys',
                    color='color',
                    alpha='alpha',
                    visible=('a' in ph),
                    source=ColumnDataSource(
                        data=dict(
                            xs=y_err_x,
                            ys=y_err_y,
                            color=df[df['obs']]['color'],
                            alpha=[1.0] * len(df[df['obs']]),
                        )
                    ),
                )

        # toggle for folded photometry
        period_toggle = CheckboxWithLegendGroup(
            labels=colors_labels.label.tolist(),
            active=list(range(len(colors_labels))),
            colors=colors_labels.color.tolist(),
            width=width // 5,
        )
        # use javascript to perform toggling on click
        # TODO replace `eval` with Namespaces
        # https://github.com/bokeh/bokeh/pull/6340
        period_toggle.js_on_click(
            CustomJS(
                args={
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=_read_plotjs('togglep.js'),
            )
        )

        # set up period adjustment text box (period is never None here)
        period_title = Div(text="Period (days): ")
        period_textinput = TextInput(value=str(period))
        period_textinput.js_on_change(
            'value',
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=_read_plotjs('foldphase.js'),
            ),
        )

        # a way to modify the period
        period_double_button = Button(label="*2")
        period_double_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(2.*period).toFixed(9);
                """,
            )
        )
        period_halve_button = Button(label="/2")
        period_halve_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(period/2.).toFixed(9);
                """,
            )
        )
        # a way to select the period
        period_selection.js_on_click(
            CustomJS(
                args={'textinput': period_textinput, 'periods': period_list},
                code="""
                textinput.value = parseFloat(periods[this.active]).toFixed(9);
                """,
            )
        )
        phase_selection.js_on_click(
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=_read_plotjs('foldphase.js'),
            )
        )

        # layout

        period_column = column(
            period_toggle,
            period_title,
            period_textinput,
            period_selection,
            row(period_double_button, period_halve_button, width=180),
            phase_selection,
            width=180,
        )

        period_layout = column(
            row(period_plot, period_column),
            sizing_mode='scale_width',
            width=width,
        )

        # Period panel
        p3 = Panel(child=period_layout, title='Period')

        # tabs for mag, flux, period
        tabs = Tabs(tabs=[p2, p1, p3], width=width, height=height, sizing_mode='fixed')
    else:
        # tabs for mag, flux
        tabs = Tabs(tabs=[p2, p1], width=width, height=height, sizing_mode='fixed')

    return bokeh_embed.json_item(tabs)
class ViewerWidgets(object):
    """
    Encapsulates Bokeh widgets, and related callbacks, that are part of prospect's GUI.
    Except for VI widgets.

    Widgets are created incrementally by the ``add_*`` methods. Some methods read
    attributes created by earlier ones (e.g. ``add_oii_widgets``, ``add_model_select``
    and ``add_update_plot_callback`` use ``self.z_input`` from ``add_redshift_widgets``;
    ``add_update_plot_callback`` uses the ``shortcds_table_*`` attributes set by
    ``add_metadata_tables``), so callers must invoke them in a compatible order.
    """

    def __init__(self, plots, nspec):
        """
        Create the spectrum-selection and smoothing sliders.

        Parameters
        ----------
        plots : object
            Plot container; ``plot_width`` and ``plot_height`` are read to size widgets.
        nspec : int
            Number of spectra available for display.
        """
        self.js_files = get_resources('js')
        self.navigation_button_width = 30
        self.z_button_width = 30
        self.plot_widget_width = (plots.plot_width+(plots.plot_height//2))//2 - 40 # used for widgets scaling

        #-----
        #- Ifiberslider and smoothing widgets
        # Ifiberslider's value controls which spectrum is displayed
        # These two widgets call update_plot(), later defined
        slider_end = nspec-1 if nspec > 1 else 0.5 # Slider cannot have start=end
        self.ifiberslider = Slider(start=0, end=slider_end, value=0, step=1,
                                   title='Spectrum (of '+str(nspec)+')')
        self.smootherslider = Slider(start=0, end=26, value=0, step=1.0,
                                     title='Gaussian Sigma Smooth')
        # Optional widgets, created later by add_coaddcam / add_model_select
        self.coaddcam_buttons = None
        self.model_select = None

    def add_navigation(self, nspec):
        """
        Add "<" / ">" buttons that decrement / increment ``self.ifiberslider``.

        Parameters
        ----------
        nspec : int
            Number of spectra; the ">" button never moves past index ``nspec-1``.
        """
        #-----
        #- Navigation buttons
        self.prev_button = Button(label="<", width=self.navigation_button_width)
        self.next_button = Button(label=">", width=self.navigation_button_width)
        self.prev_callback = CustomJS(
            args=dict(ifiberslider=self.ifiberslider),
            code="""
            if(ifiberslider.value>0 && ifiberslider.end>=1) {
                ifiberslider.value--
            }
            """)
        self.next_callback = CustomJS(
            args=dict(ifiberslider=self.ifiberslider, nspec=nspec),
            code="""
            if(ifiberslider.value<nspec-1 && ifiberslider.end>=1) {
                ifiberslider.value++
            }
            """)
        self.prev_button.js_on_event('button_click', self.prev_callback)
        self.next_button.js_on_event('button_click', self.next_callback)

    def add_resetrange(self, viewer_cds, plots):
        """
        Add a "Reset X-Y range" button, whose JS code is built from the
        ``adapt_plotrange.js`` and ``reset_plotrange.js`` resources.
        """
        #-----
        #- Axis reset button (supersedes the default bokeh "reset")
        self.reset_plotrange_button = Button(label="Reset X-Y range", button_type="default")
        reset_plotrange_code = self.js_files["adapt_plotrange.js"] + self.js_files["reset_plotrange.js"]
        self.reset_plotrange_callback = CustomJS(
            args=dict(fig=plots.fig, xmin=plots.xmin, xmax=plots.xmax,
                      spectra=viewer_cds.cds_spectra),
            code=reset_plotrange_code)
        self.reset_plotrange_button.js_on_event('button_click', self.reset_plotrange_callback)

    def add_redshift_widgets(self, z, viewer_cds, plots):
        """
        Add the redshift controls and the observer/rest-frame switch.

        The redshift is split into a rough part (``zslider``, steps of 0.01) and a
        fine part (``dzslider``, steps of 0.0001); ``z_input`` holds the full value.
        The three widgets keep each other in sync through JS callbacks, each of which
        only writes when the value actually differs, to avoid recursive updates.

        Parameters
        ----------
        z : float
            Initial redshift.
        """
        ## TODO handle "z" (same issue as viewerplots TBD)

        #-----
        #- Redshift / wavelength scale widgets
        z1 = np.floor(z*100)/100
        dz = z-z1
        self.zslider = Slider(start=-0.1, end=5.0, value=z1, step=0.01, title='Redshift rough tuning')
        self.dzslider = Slider(start=0.0, end=0.0099, value=dz, step=0.0001, title='Redshift fine-tuning')
        self.dzslider.format = "0[.]0000"
        self.z_input = TextInput(value="{:.4f}".format(z), title="Redshift value:")

        #- Observer vs. Rest frame wavelengths
        self.waveframe_buttons = RadioButtonGroup(
            labels=["Obs", "Rest"], active=0)

        self.zslider_callback = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            // Protect against 1) recursive call with z_input callback;
            //   2) out-of-range zslider values (should never happen in principle)
            var z1 = Math.floor(parseFloat(z_input.value)*100) / 100
            if ( (Math.abs(zslider.value-z1) >= 0.01) &&
                 (zslider.value >= -0.1) && (zslider.value <= 5.0) ){
                var new_z = zslider.value + dzslider.value
                z_input.value = new_z.toFixed(4)
            }
            """)

        # The rough part must be computed exactly as in zslider_callback and
        # z_input_callback (floor to 0.01); otherwise the recursion guard below is
        # compared against a wrong fine part and fires on almost every change.
        self.dzslider_callback = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            var z1 = Math.floor(z*100) / 100
            var z2 = z-z1
            if ( (Math.abs(dzslider.value-z2) >= 0.0001) &&
                 (dzslider.value >= 0.0) && (dzslider.value <= 0.0099) ){
                var new_z = zslider.value + dzslider.value
                z_input.value = new_z.toFixed(4)
            }
            """)

        self.zslider.js_on_change('value', self.zslider_callback)
        self.dzslider.js_on_change('value', self.dzslider_callback)

        self.z_minus_button = Button(label="<", width=self.z_button_width)
        self.z_plus_button = Button(label=">", width=self.z_button_width)
        self.z_minus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z >= -0.09) {
                z -= 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_plus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z <= 4.99) {
                z += 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_minus_button.js_on_event('button_click', self.z_minus_callback)
        self.z_plus_button.js_on_event('button_click', self.z_plus_callback)

        self.zreset_button = Button(label='Reset to z_pipe')
        self.zreset_callback = CustomJS(
            args=dict(z_input=self.z_input, metadata=viewer_cds.cds_metadata,
                      ifiberslider=self.ifiberslider),
            code="""
            var ifiber = ifiberslider.value
            var z = metadata.data['Z'][ifiber]
            z_input.value = z.toFixed(4)
            """)
        self.zreset_button.js_on_event('button_click', self.zreset_callback)

        self.z_input_callback = CustomJS(
            args=dict(spectra = viewer_cds.cds_spectra,
                      coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                      model = viewer_cds.cds_model,
                      othermodel = viewer_cds.cds_othermodel,
                      metadata = viewer_cds.cds_metadata,
                      ifiberslider = self.ifiberslider,
                      zslider = self.zslider,
                      dzslider = self.dzslider,
                      z_input = self.z_input,
                      waveframe_buttons = self.waveframe_buttons,
                      line_data = viewer_cds.cds_spectral_lines,
                      lines = plots.speclines,
                      line_labels = plots.specline_labels,
                      zlines = plots.zoom_speclines,
                      zline_labels = plots.zoom_specline_labels,
                      overlap_waves = plots.overlap_waves,
                      overlap_bands = plots.overlap_bands,
                      fig = plots.fig
                      ),
            code="""
            var z = parseFloat(z_input.value)
            if ( z >=-0.1 && z <= 5.0 ) {
                // update zsliders only if needed (avoid recursive call)
                z_input.value = parseFloat(z_input.value).toFixed(4)
                var z1 = Math.floor(z*100) / 100
                var z2 = z-z1
                if ( Math.abs(z1-zslider.value) >= 0.01) zslider.value = parseFloat(parseFloat(z1).toFixed(2))
                if ( Math.abs(z2-dzslider.value) >= 0.0001) dzslider.value = parseFloat(parseFloat(z2).toFixed(4))
            } else {
                if (z_input.value < -0.1) z_input.value = (-0.1).toFixed(4)
                if (z_input.value > 5) z_input.value = (5.0).toFixed(4)
            }

            var line_restwave = line_data.data['restwave']
            var ifiber = ifiberslider.value
            var waveshift_lines = (waveframe_buttons.active == 0) ? 1+z : 1 ;
            var waveshift_spec = (waveframe_buttons.active == 0) ? 1 : 1/(1+z) ;

            for(var i=0; i<line_restwave.length; i++) {
                lines[i].location = line_restwave[i] * waveshift_lines
                line_labels[i].x = line_restwave[i] * waveshift_lines
                zlines[i].location = line_restwave[i] * waveshift_lines
                zline_labels[i].x = line_restwave[i] * waveshift_lines
            }
            if (overlap_bands.length>0) {
                for (var i=0; i<overlap_bands.length; i++) {
                    overlap_bands[i].left = overlap_waves[i][0] * waveshift_spec
                    overlap_bands[i].right = overlap_waves[i][1] * waveshift_spec
                }
            }

            function shift_plotwave(cds_spec, waveshift) {
                var data = cds_spec.data
                var origwave = data['origwave']
                var plotwave = data['plotwave']
                if ( plotwave[0] != origwave[0] * waveshift ) { // Avoid redo calculation if not needed
                    for (var j=0; j<plotwave.length; j++) {
                        plotwave[j] = origwave[j] * waveshift ;
                    }
                    cds_spec.change.emit()
                }
            }

            for(var i=0; i<spectra.length; i++) {
                shift_plotwave(spectra[i], waveshift_spec)
            }
            if (coaddcam_spec) shift_plotwave(coaddcam_spec, waveshift_spec)

            // Update model wavelength array
            // NEW : don't shift model if othermodel is there
            if (othermodel) {
                var zref = othermodel.data['zref'][0]
                var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zref) : 1/(1+zref) ;
                shift_plotwave(othermodel, waveshift_model)
            } else if (model) {
                var zfit = 0.0
                if(metadata.data['Z'] !== undefined) {
                    zfit = metadata.data['Z'][ifiber]
                }
                var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zfit) : 1/(1+zfit) ;
                shift_plotwave(model, waveshift_model)
            }
            """)
        self.z_input.js_on_change('value', self.z_input_callback)
        self.waveframe_buttons.js_on_click(self.z_input_callback)

        self.plotrange_callback = CustomJS(
            args = dict(
                z_input=self.z_input,
                waveframe_buttons=self.waveframe_buttons,
                fig=plots.fig,
            ),
            code="""
            var z = parseFloat(z_input.value)
            // Observer Frame
            if(waveframe_buttons.active == 0) {
                fig.x_range.start = fig.x_range.start * (1+z)
                fig.x_range.end = fig.x_range.end * (1+z)
            } else {
                fig.x_range.start = fig.x_range.start / (1+z)
                fig.x_range.end = fig.x_range.end / (1+z)
            }
            """
        )
        # TODO: for record: is this related to waveframe bug? : 2 callbacks for same click...
        self.waveframe_buttons.js_on_click(self.plotrange_callback)

    def add_oii_widgets(self, plots):
        """
        Add "OII-zoom" / "Undo OII-zoom" buttons.

        The zoom button saves the current x-range and smoothing value in a one-row
        ColumnDataSource, centers the x-range on 3728.48*(1+z) +/- 100 and sets the
        smoothing slider to 0; the undo button restores the saved values.
        Requires ``self.z_input`` (see ``add_redshift_widgets``).
        """
        #------
        #- Zoom on the OII doublet TODO mv js code to other file
        # TODO: is there another trick than using a cds to pass the "oii_saveinfo" ?
        # TODO: optimize smoothing for autozoom (current value: 0)
        cds_oii_saveinfo = ColumnDataSource(
            {'xmin':[plots.fig.x_range.start], 'xmax':[plots.fig.x_range.end],
             'nsmooth':[self.smootherslider.value]})
        self.oii_zoom_button = Button(label="OII-zoom", button_type="default")
        self.oii_zoom_callback = CustomJS(
            args = dict(z_input=self.z_input, fig=plots.fig, smootherslider=self.smootherslider,
                        cds_oii_saveinfo=cds_oii_saveinfo),
            code = """
            // Save previous setting (for the "Undo" button)
            cds_oii_saveinfo.data['xmin'] = [fig.x_range.start]
            cds_oii_saveinfo.data['xmax'] = [fig.x_range.end]
            cds_oii_saveinfo.data['nsmooth'] = [smootherslider.value]
            // Center on the middle of the redshifted OII doublet (vaccum)
            var z = parseFloat(z_input.value)
            fig.x_range.start = 3728.48 * (1+z) - 100
            fig.x_range.end = 3728.48 * (1+z) + 100
            // No smoothing (this implies a call to update_plot)
            smootherslider.value = 0
            """)
        self.oii_zoom_button.js_on_event('button_click', self.oii_zoom_callback)

        self.oii_undo_button = Button(label="Undo OII-zoom", button_type="default")
        self.oii_undo_callback = CustomJS(
            args = dict(fig=plots.fig, smootherslider=self.smootherslider,
                        cds_oii_saveinfo=cds_oii_saveinfo),
            code = """
            fig.x_range.start = cds_oii_saveinfo.data['xmin'][0]
            fig.x_range.end = cds_oii_saveinfo.data['xmax'][0]
            smootherslider.value = cds_oii_saveinfo.data['nsmooth'][0]
            """)
        self.oii_undo_button.js_on_event('button_click', self.oii_undo_callback)

    def add_coaddcam(self, plots):
        """
        Add a radio group choosing whether camera-coadded or single-arm spectra are
        highlighted. The non-highlighted lines get ``plots.alpha_discrete`` as line
        alpha; overlap bands are only shown (``plots.alpha_overlapband``) in
        "Camera-coadded" mode.
        """
        #-----
        #- Highlight individual-arm or camera-coadded spectra
        coaddcam_labels = ["Camera-coadded", "Single-arm"]
        self.coaddcam_buttons = RadioButtonGroup(labels=coaddcam_labels, active=0)
        self.coaddcam_callback = CustomJS(
            args = dict(coaddcam_buttons = self.coaddcam_buttons,
                        list_lines=[plots.data_lines, plots.noise_lines,
                                    plots.zoom_data_lines, plots.zoom_noise_lines],
                        alpha_discrete = plots.alpha_discrete,
                        overlap_bands = plots.overlap_bands,
                        alpha_overlapband = plots.alpha_overlapband),
            code="""
            var n_lines = list_lines[0].length
            for (var i=0; i<n_lines; i++) {
                var new_alpha = 1
                if (coaddcam_buttons.active == 0 && i<n_lines-1) new_alpha = alpha_discrete
                if (coaddcam_buttons.active == 1 && i==n_lines-1) new_alpha = alpha_discrete
                for (var j=0; j<list_lines.length; j++) {
                    list_lines[j][i].glyph.line_alpha = new_alpha
                }
            }
            var new_alpha = 0
            if (coaddcam_buttons.active == 0) new_alpha = alpha_overlapband
            for (var j=0; j<overlap_bands.length; j++) {
                overlap_bands[j].fill_alpha = new_alpha
            }
            """
        )
        self.coaddcam_buttons.js_on_click(self.coaddcam_callback)

    def add_metadata_tables(self, viewer_cds, show_zcat=True, template_dicts=None,
                            top_metadata=('TARGETID', 'EXPID')):
        """ Display object-related informations
                top_metadata: metadata to be highlighted in table_a

            Note: "short" CDS, with a single row, are used to fill these bokeh tables.
            When changing object, js code modifies these short CDS so that tables are updated.

        Parameters
        ----------
        viewer_cds : object
            Provides ``cds_metadata``, ``phot_bands`` and (for the zcat table) ``zcat_keys``.
        show_zcat : bool or None
            If None or False, ``self.table_z`` is a "Not available" Div and
            ``self.shortcds_table_z`` is None.
        template_dicts : sequence or None
            If given, element [1] holds the fit results used to build the
            multi-fit redshift table.
        top_metadata : iterable of str
            Keys shown in table_a. Never mutated; a local copy is extended with the
            FIRST_/LAST_/NUM_ variants of these keys when present.
        """
        # Private copy: the FIRST_/LAST_/NUM_ keys appended below must not leak into
        # the default value (shared across calls) or into a list owned by the caller.
        top_metadata = list(top_metadata)

        #- Sorted list of potential metadata:
        metadata_to_check = ['TARGETID', 'HPXPIXEL', 'TILEID', 'COADD_NUMEXP', 'COADD_EXPTIME',
                             'NIGHT', 'EXPID', 'FIBER', 'CAMERA', 'MORPHTYPE']
        metadata_to_check += [ ('mag_'+x) for x in viewer_cds.phot_bands ]
        table_keys = []
        for key in metadata_to_check:
            if key in viewer_cds.cds_metadata.data.keys():
                table_keys.append(key)
            if 'NUM_'+key in viewer_cds.cds_metadata.data.keys():
                for prefix in ['FIRST','LAST','NUM']:
                    table_keys.append(prefix+'_'+key)
                    if key in top_metadata:
                        top_metadata.append(prefix+'_'+key)

        #- Table a: "top metadata"
        table_a_keys = [ x for x in table_keys if x in top_metadata ]
        self.shortcds_table_a, self.table_a = _metadata_table(
            table_a_keys, viewer_cds, table_width=600,
            shortcds_name='shortcds_table_a', selectable=True)
        #- Table b: Targeting information
        self.shortcds_table_b, self.table_b = _metadata_table(
            ['Targeting masks'], viewer_cds, table_width=self.plot_widget_width,
            shortcds_name='shortcds_table_b', selectable=True)
        #- Table(s) c/d : Other information (imaging, etc.)
        remaining_keys = [ x for x in table_keys if x not in top_metadata ]
        if len(remaining_keys) > 7:
            table_c_keys = remaining_keys[0:len(remaining_keys)//2]
            table_d_keys = remaining_keys[len(remaining_keys)//2:]
        else:
            table_c_keys = remaining_keys
            table_d_keys = None
        self.shortcds_table_c, self.table_c = _metadata_table(
            table_c_keys, viewer_cds, table_width=self.plot_widget_width,
            shortcds_name='shortcds_table_c', selectable=False)
        if table_d_keys is None:
            self.shortcds_table_d, self.table_d = None, None
        else:
            self.shortcds_table_d, self.table_d = _metadata_table(
                table_d_keys, viewer_cds, table_width=self.plot_widget_width,
                shortcds_name='shortcds_table_d', selectable=False)

        #- Table z: redshift fitting information
        # An explicit False must disable the table too (previously only None did).
        if show_zcat is None or show_zcat is False:
            self.table_z = Div(text="Not available ")
            self.shortcds_table_z = None
        elif template_dicts is not None : # Add other best fits
            fit_results = template_dicts[1]
            # Case of DeltaChi2 : compute it from Chi2s
            #    The "DeltaChi2" in rr fits is between best fits for a given (spectype,subtype)
            # Convention: DeltaChi2 = -1 for the last fit.
            chi2s = fit_results['CHI2'][0]
            full_deltachi2s = np.zeros(len(chi2s))-1
            full_deltachi2s[:-1] = chi2s[1:]-chi2s[:-1]
            cdsdata = dict(Nfit = np.arange(1,len(chi2s)+1),
                           SPECTYPE = fit_results['SPECTYPE'][0],  # [0:num_best_fits] (if we want to restrict... TODO?)
                           SUBTYPE = fit_results['SUBTYPE'][0],
                           Z = [ "{:.4f}".format(x) for x in fit_results['Z'][0] ],
                           ZERR = [ "{:.4f}".format(x) for x in fit_results['ZERR'][0] ],
                           ZWARN = fit_results['ZWARN'][0],
                           CHI2 = [ "{:.1f}".format(x) for x in fit_results['CHI2'][0] ],
                           DELTACHI2 = [ "{:.1f}".format(x) for x in full_deltachi2s ])
            self.shortcds_table_z = ColumnDataSource(cdsdata, name='shortcds_table_z')
            columns_table_z = [ TableColumn(field=x, title=t, width=w) for x,t,w in [
                ('Nfit','Nfit',5), ('SPECTYPE','SPECTYPE',70), ('SUBTYPE','SUBTYPE',60),
                ('Z','Z',50) , ('ZERR','ZERR',50), ('ZWARN','ZWARN',50),
                ('DELTACHI2','Δχ2(N+1/N)',70)] ]
            self.table_z = DataTable(source=self.shortcds_table_z, columns=columns_table_z,
                                     selectable=False, index_position=None,
                                     width=self.plot_widget_width)
            self.table_z.height = 3 * self.table_z.row_height
        else :
            self.shortcds_table_z, self.table_z = _metadata_table(
                viewer_cds.zcat_keys, viewer_cds, table_width=self.plot_widget_width,
                shortcds_name='shortcds_table_z', selectable=False)

    def add_specline_toggles(self, viewer_cds, plots):
        """
        Add widgets toggling the visibility of emission / absorption spectral lines
        (main and zoom figures), optionally restricted to lines flagged 'major'.
        """
        #-----
        #- Toggle lines
        self.speclines_button_group = CheckboxButtonGroup(
                labels=["Emission lines", "Absorption lines"], active=[])
        self.majorline_checkbox = CheckboxGroup(
                labels=['Show only major lines'], active=[])

        self.speclines_callback = CustomJS(
            args = dict(line_data = viewer_cds.cds_spectral_lines,
                        lines = plots.speclines,
                        line_labels = plots.specline_labels,
                        zlines = plots.zoom_speclines,
                        zline_labels = plots.zoom_specline_labels,
                        lines_button_group = self.speclines_button_group,
                        majorline_checkbox = self.majorline_checkbox),
            code="""
            var show_emission = false
            var show_absorption = false
            if (lines_button_group.active.indexOf(0) >= 0) {  // index 0=Emission in active list
                show_emission = true
            }
            if (lines_button_group.active.indexOf(1) >= 0) {  // index 1=Absorption in active list
                show_absorption = true
            }

            for(var i=0; i<lines.length; i++) {
                if ( !(line_data.data['major'][i]) && (majorline_checkbox.active.indexOf(0)>=0) ) {
                    lines[i].visible = false
                    line_labels[i].visible = false
                    zlines[i].visible = false
                    zline_labels[i].visible = false
                } else if (line_data.data['emission'][i]) {
                    lines[i].visible = show_emission
                    line_labels[i].visible = show_emission
                    zlines[i].visible = show_emission
                    zline_labels[i].visible = show_emission
                } else {
                    lines[i].visible = show_absorption
                    line_labels[i].visible = show_absorption
                    zlines[i].visible = show_absorption
                    zline_labels[i].visible = show_absorption
                }
            }
            """
        )
        self.speclines_button_group.js_on_click(self.speclines_callback)
        self.majorline_checkbox.js_on_click(self.speclines_callback)

    def add_model_select(self, viewer_cds, template_dicts, num_approx_fits, with_full_2ndfit=True):
        """
        Add a Select widget choosing the secondary ("other") model to display.

        Parameters
        ----------
        template_dicts : sequence
            [0] fit templates, [1] fit results, [2] standard templates (passed to JS).
        num_approx_fits : int
            Number of "Nth fit (approx)" options to offer.
        with_full_2ndfit : bool
            If False, the '2nd best fit' option is removed.
        """
        #------
        #- Select secondary model to display
        model_options = ['Best fit', '2nd best fit']
        for i in range(1,1+num_approx_fits) :
            # English ordinal suffix: 1st 2nd 3rd 4th ... 11th 12th 13th ... 21st 22nd ...
            if 11 <= i % 100 <= 13:
                ith = 'th'
            else:
                ith = {1: 'st', 2: 'nd', 3: 'rd'}.get(i % 10, 'th')
            model_options.append(str(i)+ith+' fit (approx)')
        if with_full_2ndfit is False :
            model_options.remove('2nd best fit')
        for std_template in ['QSO', 'GALAXY', 'STAR'] :
            model_options.append('STD '+std_template)
        self.model_select = Select(value=model_options[0], title="Other model (dashed curve):",
                                   options=model_options)
        model_select_code = (self.js_files["interp_grid.js"] + self.js_files["smooth_data.js"] +
                             self.js_files["select_model.js"])
        self.model_select_callback = CustomJS(
            args = dict(ifiberslider = self.ifiberslider,
                        model_select = self.model_select,
                        fit_templates=template_dicts[0],
                        cds_othermodel = viewer_cds.cds_othermodel,
                        cds_model_2ndfit = viewer_cds.cds_model_2ndfit,
                        cds_model = viewer_cds.cds_model,
                        fit_results=template_dicts[1],
                        std_templates=template_dicts[2],
                        median_spectra = viewer_cds.cds_median_spectra,
                        smootherslider = self.smootherslider,
                        z_input = self.z_input,
                        cds_metadata = viewer_cds.cds_metadata),
            code = model_select_code)
        self.model_select.js_on_change('value', self.model_select_callback)

    def add_update_plot_callback(self, viewer_cds, plots, vi_widgets, template_dicts):
        """
        Build the main ``update_plot`` CustomJS callback and attach it to the
        smoothing and spectrum-index sliders.

        Requires ``add_redshift_widgets`` (``self.z_input``) and
        ``add_metadata_tables`` (``self.shortcds_table_*``) to have run first;
        ``self.model_select`` may still be None if ``add_model_select`` was not called.
        """
        #-----
        #- Main js code to update plots
        update_plot_code = (self.js_files["adapt_plotrange.js"] + self.js_files["interp_grid.js"] +
                            self.js_files["smooth_data.js"] + self.js_files["coadd_brz_cameras.js"] +
                            self.js_files["update_plot.js"])
        # TMP handling of template_dicts
        the_fit_results = None if template_dicts is None else template_dicts[1] # dirty
        self.update_plot_callback = CustomJS(
            args = dict(
                spectra = viewer_cds.cds_spectra,
                coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                model = viewer_cds.cds_model,
                othermodel = viewer_cds.cds_othermodel,
                model_2ndfit = viewer_cds.cds_model_2ndfit,
                metadata = viewer_cds.cds_metadata,
                fit_results = the_fit_results,
                shortcds_table_z = self.shortcds_table_z,
                shortcds_table_a = self.shortcds_table_a,
                shortcds_table_b = self.shortcds_table_b,
                shortcds_table_c = self.shortcds_table_c,
                shortcds_table_d = self.shortcds_table_d,
                ifiberslider = self.ifiberslider,
                smootherslider = self.smootherslider,
                z_input = self.z_input,
                fig = plots.fig,
                imfig_source = plots.imfig_source,
                imfig_urls = plots.imfig_urls,
                model_select = self.model_select,
                vi_comment_input = vi_widgets.vi_comment_input,
                vi_std_comment_select = vi_widgets.vi_std_comment_select,
                vi_name_input = vi_widgets.vi_name_input,
                vi_quality_input = vi_widgets.vi_quality_input,
                vi_quality_labels = vi_widgets.vi_quality_labels,
                vi_issue_input = vi_widgets.vi_issue_input,
                vi_z_input = vi_widgets.vi_z_input,
                vi_category_select = vi_widgets.vi_category_select,
                vi_issue_slabels = vi_widgets.vi_issue_slabels
                ),
            code = update_plot_code
        )
        self.smootherslider.js_on_change('value', self.update_plot_callback)
        self.ifiberslider.js_on_change('value', self.update_plot_callback)