Exemplo n.º 1
0
# Script fragment: builds the score / expected-wins plots and table, hooks up
# the widget callbacks, and starts assembling the tabbed layout.  All helpers,
# widgets and handlers referenced here are defined outside this fragment.

# One empty legend label per team.
legend_labels = ['' for _ in range(num_teams)]

# Presumably draws the score lines; the return value is kept as sc_renderers.
sc_renderers = plot_sc_data(team_objs, sc_sources, line_colors)

# NOTE(review): return value is discarded, so this is presumably called for its
# side effects on the objects passed in -- confirm against the helper.
compile_expected_wins(league_obj, team_objs, weeks, owner_to_idx, num_teams)

ew_sources = get_ew_sources(weeks, team_objs, owners, week_num, num_teams)

# The expected-wins table is wrapped in a column so it can be used as a tab
# child below.
expected_wins_table = initialize_ew_table(team_objs, week_num, num_teams)
table_wrap = column(children=[expected_wins_table])

ew_renderers = plot_ew_data(team_objs, ew_sources, line_colors)

# register callback handlers to respond to changes in widget values
lg_id_input.on_change('value', league_id_handler)
# The league id input additionally triggers a JavaScript-side callback.
lg_id_input.js_on_change('value', ga_view_callback)
week_slider.on_change('value', week_slider_handler)
team1_dd.on_change('value', team1_select_handler)
team2_dd.on_change('value', team2_select_handler)
comp_button.on_click(helper_handler)
year_input.on_change('value', season_handler)

# arrange layout
# One tab per view: scores plot, expected-wins plot, summary table.
tab1 = Panel(child=plot1_wrap, title='Scores')
tab2 = Panel(child=plot2_wrap, title='Expected Wins')
tab3 = Panel(child=table_wrap, title='Summary')

figures = Tabs(tabs=[tab1, tab2, tab3], width=500)

# Team selectors and the compare button stacked vertically.
compare_widgets = column(team1_dd, team2_dd, comp_button)
Exemplo n.º 2
0
Arquivo: navia.py Projeto: d-que/navia
### MASS FINDER ###
# Widgets for the "Mass Finder" panel: a mass slider, a charge-range slider,
# an exact-mass text field and a toggle for showing the predicted m/z lines.
mass_finder_header = Div(text= " <h2>Mass Finder</h2>", height=45, width=400 )
# mass_finder_range_text = Div(text= " Range mz:", width= 150, height=30 )
mass_finder_range_slider = RangeSlider(start=1.0, end=500.0, value=(1.0,50.0), title='Charge range:',name='mass_finder_range_slider', step=1, width= 250, height=30)
# mass_finder_mass_text = Div(text= " Mass of Complex (kDa):", width= 150, height=30 )
# NOTE(review): name='gau_sigma' looks copied from another slider -- left
# unchanged because other code may look the model up by this name.
mass_finder_mass = Slider(value=100, start=0.0, end=1000.0, step=10.0, title='Mass of Complex (kDa)',name='gau_sigma', width=250, height=30)

mass_finder_exact_mass_text = Div(text= "Enter exact Mass (Da)", width= 150, height=30 )
# The slider is in kDa while the text field is in Da, hence the factor 1000.
mass_finder_exact_mass_sele = TextInput(value=str(mass_finder_mass.value*1000), disabled=False, width=100, height=30)

mass_finder_line_text = Div(text= "Show mz prediction", width= 150, height=30 )
mass_finder_line_sele = Toggle(label='off', active=False, width=100, height=30, callback=toggle_cb)

# Load the JavaScript callback bodies with context managers so the file
# handles are closed deterministically (the bare open(...).read() leaked them).
mass_finder_js_dir = os.path.join(os.getcwd(), 'JS_Functions')
with open(os.path.join(mass_finder_js_dir, "mass_finder_cb.js")) as mass_finder_js_file:
    mass_finder_cb_code = mass_finder_js_file.read()
with open(os.path.join(mass_finder_js_dir, "mass_finder_exact_cb.js")) as mass_finder_js_file:
    mass_finder_exact_cb_code = mass_finder_js_file.read()

mass_finder_cb = CustomJS(
    args=dict(mass_finder_line_sele=mass_finder_line_sele, raw_mz=raw_mz,
              mass_finder_data=mass_finder_data,
              mass_finder_exact_mass_sele=mass_finder_exact_mass_sele,
              mass_finder_mass=mass_finder_mass,
              mass_finder_range_slider=mass_finder_range_slider, mfl=mfl),
    code=mass_finder_cb_code)
mass_finder_exact_cb = CustomJS(
    args=dict(mass_finder_line_sele=mass_finder_line_sele,
              mass_finder_exact_mass_sele=mass_finder_exact_mass_sele,
              mass_finder_mass=mass_finder_mass),
    code=mass_finder_exact_cb_code)
mass_finder_exact_mass_sele.js_on_change('value', mass_finder_exact_cb)

# The whole panel starts hidden; its visibility follows the `mass_finder`
# model's `active` property, and the prediction glyph `mfl` follows the toggle.
mass_finder_column=Column(mass_finder_header,mass_finder_mass, mass_finder_range_slider, Row(mass_finder_exact_mass_text,mass_finder_exact_mass_sele), Row(mass_finder_line_text, mass_finder_line_sele), visible=False)
mass_finder.js_link('active', mass_finder_column, 'visible')
mass_finder_line_sele.js_link('active', mfl, 'visible')
# Any change to the mass, the toggle, or the charge range re-runs the callback.
mass_finder_mass.js_on_change('value', mass_finder_cb)
mass_finder_line_sele.js_on_change('active', mass_finder_cb)
mass_finder_range_slider.js_on_change('value',mass_finder_cb)
### DATA PROCESSING ###
# Fixed-size text labels (150x30) for the data-processing controls; the
# controls themselves are not part of this fragment.

cropping = Div(text= " Range mz:", width= 150, height=30 )
# crop_max = Div(text= " ", width= 150, height=30 )
gau_name = Div(text= " Gaussian Smoothing:", width= 150, height=30 )
n_smooth_name = Div(text= " Repeats of Smoothing:", width= 150, height=30 )
# bin_name = Div(text= " Bin Every:", width= 150, height=30 )
int_name = Div(text= " Intensity Threshold (%)", width= 150, height=30 )
Exemplo n.º 3
0
def spectroscopy_plot(obj_id, spec_id=None):
    """Build the interactive spectroscopy plot for one object.

    The layout contains the normalized spectra (one line per spectrum), a
    checkbox legend to toggle individual spectra, three checkbox groups that
    overlay the spectral lines from ``SPEC_LINES``, and slider/text inputs
    for redshift and expansion velocity that shift those lines.

    Args:
        obj_id: Identifier passed to ``Obj.query.get``.
        spec_id: Optional spectrum id; when given, only the spectrum whose
            ``id`` equals ``int(spec_id)`` is plotted.

    Returns:
        ``(None, None, None)`` when the object does not exist or has no
        matching spectra; otherwise the result of ``_plot_to_json(layout)``.
    """
    # Query the object once and reuse it (the original queried twice and
    # raised AttributeError on a missing object).
    obj = Obj.query.get(obj_id)
    if obj is None:
        return None, None, None

    spectra = obj.spectra
    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if len(spectra) == 0:
        return None, None, None

    color_map = dict(zip([s.id for s in spectra], viridis(len(spectra))))

    data = []
    for s in spectra:

        # normalize spectra to a common average flux per resolving
        # element of 1 (facilitates easy visual comparison)
        normfac = np.sum(np.gradient(s.wavelengths) * s.fluxes) / len(s.fluxes)

        if not (np.isfinite(normfac) and normfac > 0):
            # otherwise normalize the value at the median wavelength to 1
            median_wave_index = np.argmin(
                np.abs(s.wavelengths - np.median(s.wavelengths)))
            normfac = s.fluxes[median_wave_index]

        df = pd.DataFrame({
            'wavelength':
            s.wavelengths,
            'flux':
            s.fluxes / normfac,
            'id':
            s.id,
            'telescope':
            s.instrument.telescope.name,
            'instrument':
            s.instrument.name,
            'date_observed':
            s.observed_at.date().isoformat(),
            'pi':
            s.assignment.run.pi if s.assignment is not None else "",
        })
        data.append(df)
    data = pd.concat(data)

    dfs = []
    for s in spectra:
        # Smooth the spectrum by using a rolling average
        df = (pd.DataFrame({
            'wavelength': s.wavelengths,
            'flux': s.fluxes
        }).rolling(2).mean(numeric_only=True).dropna())
        dfs.append(df)

    smoothed_data = pd.concat(dfs)

    # sort=False keeps the groups in order of first appearance, i.e. in the
    # order of `spectra`.  The toggle labels below are built from `spectra`,
    # so the i-th label, color and line ('s<i>') now refer to the same
    # spectrum even when the spectra are not ordered by id.
    split = data.groupby('id', sort=False)
    hover = HoverTool(tooltips=[
        ('wavelength', '$x'),
        ('flux', '$y'),
        ('telescope', '@telescope'),
        ('instrument', '@instrument'),
        ('UTC date observed', '@date_observed'),
        ('PI', '@pi'),
    ])
    smoothed_max = np.max(smoothed_data['flux'])
    smoothed_min = np.min(smoothed_data['flux'])
    ymax = smoothed_max * 1.05
    ymin = smoothed_min - 0.05 * (smoothed_max - smoothed_min)
    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    plot = figure(
        plot_width=600,
        plot_height=300,
        y_range=(ymin, ymax),
        x_range=(xmin, xmax),
        sizing_mode='scale_both',
        tools='box_zoom,wheel_zoom,pan,reset',
        active_drag='box_zoom',
    )
    plot.add_tools(hover)
    # Maps the JS-side variable names ('s0', 's1', ..., 'el0', ...) to their
    # renderers; it is splatted into the CustomJS args below.
    model_dict = {}
    for i, (key, df) in enumerate(split):
        model_dict['s' + str(i)] = plot.line(x='wavelength',
                                             y='flux',
                                             color=color_map[key],
                                             source=ColumnDataSource(df))
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None

    # TODO how to choose a good default?
    # NOTE: this replaces the (ymin, ymax) range given to figure() above.
    plot.y_range = Range1d(0, 1.03 * data.flux.max())

    toggle = CheckboxWithLegendGroup(
        labels=[
            f'{s.instrument.telescope.nickname}/{s.instrument.name} ({s.observed_at.date().isoformat()})'
            for s in spectra
        ],
        active=list(range(len(spectra))),
        colors=[color_map[k] for k, df in split],
    )
    toggle.callback = CustomJS(
        args={
            'toggle': toggle,
            **model_dict
        },
        code="""
          for (let i = 0; i < toggle.labels.length; i++) {
              eval("s" + i).visible = (toggle.active.includes(i))
          }
    """,
    )

    # Fall back to z = 0 when the object has no redshift; computed once and
    # reused for the slider, the text input and the initial line positions.
    redshift = obj.redshift if obj.redshift is not None else 0.0

    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=redshift,
        start=0.0,
        end=1.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(redshift))
    # Moving the slider writes into the text input; the text input's own
    # callback (registered further down) then updates the spectral lines.
    z_slider.callback = CustomJS(
        args={
            'slider': z_slider,
            'textinput': z_textinput
        },
        code="""
            textinput.value = slider.value.toFixed(3).toString();
            textinput.change.emit();
        """,
    )
    z = column(z_title, z_slider, z_textinput)

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.callback = CustomJS(
        args={
            'slider': v_exp_slider,
            'textinput': v_exp_textinput
        },
        code="""
            textinput.value = slider.value.toFixed(0).toString();
            textinput.change.emit();
        """,
    )
    v_exp = column(v_title, v_exp_slider, v_exp_textinput)

    # One (initially hidden) segment renderer per SPEC_LINES entry, placed at
    # the rest wavelengths shifted by the object's redshift.
    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            y0=0,
            y1=1e-13,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral lines into 3 columns
    element_dicts = np.array_split(
        np.array(list(SPEC_LINES.items()), dtype=object), 3)
    elements_groups = []
    # Index of the first 'el<i>' renderer that belongs to the current column.
    col_offset = 0
    for element_dict in element_dicts:
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(
            labels=labels,
            active=[],
            colors=colors,
        )
        elements_groups.append(elements)

        # TODO callback policy: don't require submit for text changes?
        elements.callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {col_offset} + elements.labels.length;
            let local_i = 0;
            for (let i = {col_offset}; i < i_max; i++) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )

        col_offset += len(labels)

    # Our current version of Bokeh doesn't properly execute multiple callbacks
    # https://github.com/bokeh/bokeh/issues/6508
    # Workaround is to manually put the code snippets together
    #
    # NOTE: the two code strings below are plain (non-f) strings, so their
    # "{{" / "}}" reach the browser verbatim; that is still valid JavaScript
    # (a block nested in a block) and is deliberately left untouched.
    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'elements0': elements_groups[0],
                'elements1': elements_groups[1],
                'elements2': elements_groups[2],
                'z': z_textinput,
                'slider': z_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(z.value).toFixed(3);

            // Update plot data for each element
            let c = 299792.458; // speed of light in km / s
            const offset_col_1 = elements0.labels.length;
            const offset_col_2 = offset_col_1 + elements1.labels.length;
            const i_max = offset_col_2 + elements2.labels.length;
            for (let i = 0; i < i_max; i++) {{
                let el = eval("el" + i);
                el.visible =
                    elements0.active.includes(i) ||
                    elements1.active.includes(i - offset_col_1) ||
                    elements2.active.includes(i - offset_col_2);
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
            }}
        """,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'elements0': elements_groups[0],
                'elements1': elements_groups[1],
                'elements2': elements_groups[2],
                'z': z_textinput,
                'slider': v_exp_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(v_exp.value).toFixed(3);

            // Update plot data for each element
            let c = 299792.458; // speed of light in km / s
            const offset_col_1 = elements0.labels.length;
            const offset_col_2 = offset_col_1 + elements1.labels.length;
            const i_max = offset_col_2 + elements2.labels.length;
            for (let i = 0; i < i_max; i++) {{
                let el = eval("el" + i);
                el.visible =
                    elements0.active.includes(i) ||
                    elements1.active.includes(i - offset_col_1) ||
                    elements2.active.includes(i - offset_col_2);
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
            }}
        """,
        ),
    )

    row1 = row(plot, toggle)
    row2 = row(elements_groups)
    row3 = row(z, v_exp)
    layout = column(row1, row2, row3)
    return _plot_to_json(layout)
Exemplo n.º 4
0
def compare():
    """Render the two-project pathway comparison page.

    Reads the optional ``proj_a`` / ``proj_b`` / ``include`` query arguments,
    picks two of the current user's completed uploads (the requested pair if
    both ids are valid, otherwise the two with the highest ``file_id``),
    builds a Bokeh scatter grid of their pathway effect sizes, and renders
    ``pway/compare.html``.

    Returns:
        The rendered template, or a redirect to ``.index`` (with a flashed
        warning) when the user has fewer than two completed projects.
    """
    # if proj among arguments, show this tree first.
    try:
        proj_a = int(request.args.get('proj_a', None))
        proj_b = int(request.args.get('proj_b', None))
    except (TypeError, ValueError):
        # Missing or non-numeric ids: fall back to the default pair below.
        proj_a = proj_b = None
    include = request.args.get('include', None)

    # list of projects (and proj_names) used to create dropdown project selector
    upload_list = UserFile.query.filter_by(user_id=current_user.id).\
        filter_by(run_complete=True).order_by(UserFile.file_id).all()

    if len(upload_list) > 1:
        # Use the specified projects from args, or the two highest file_ids.
        if proj_a and proj_b:
            current_temp_a = [u for u in upload_list if u.file_id == proj_a]
            current_temp_b = [u for u in upload_list if u.file_id == proj_b]
            # if not among user's finished projects, use highest file_id
            if len(current_temp_a) == 1 and len(current_temp_b) == 1:
                current_proj_a = current_temp_a[0]
                current_proj_b = current_temp_b[0]
            else:
                current_proj_a = upload_list[-2]
                current_proj_b = upload_list[-1]
        else:
            current_proj_a = upload_list[-2]
            current_proj_b = upload_list[-1]
        detail_path1 = naming_rules.get_detailed_path(current_proj_a)
        detail_path2 = naming_rules.get_detailed_path(current_proj_b)
        js_name1 = naming_rules.get_js_name(current_proj_a)
        js_name2 = naming_rules.get_js_name(current_proj_b)
        xlabel = u"Effect size ({})".format(current_proj_a.get_fancy_filename())
        ylabel = u"Effect size ({})".format(current_proj_b.get_fancy_filename())

        # load pathways with 1+ mutation in 1+ patients,
        # ignoring ones with 'cancer' etc in name
        all_paths1 = load_pathway_list_from_file(detail_path1)
        all_paths2 = load_pathway_list_from_file(detail_path2)

        # IDs with p<0.05 and +ve effect
        sig_p = OrderedDict(
            [(i.path_id, i.nice_name) for i in all_paths1 if i.gene_set])
        sig_p2 = OrderedDict(
            [(i.path_id, i.nice_name) for i in all_paths2 if i.gene_set])
        # Position of each pathway id within its own project's list.  These
        # must be built before the merge below, and replace the repeated
        # list.index() scans with constant-time lookups.
        sig_pos1 = {pid: pos for pos, pid in enumerate(sig_p)}
        sig_pos2 = {pid: pos for pos, pid in enumerate(sig_p2)}
        sig_p.update(sig_p2)  # ORDERED by proj1 effect size

        # BUILD DATAFRAME WITH ALL sig PATHWAYS, proj1 object order.
        # Materialize the dict views so pandas receives plain lists.
        pway_names = list(sig_p.values())  # order important
        columns = ['path_id', 'pname', 'ind1', 'ind2', 'e1', 'e2', 'e1_only',
                   'e2_only', 'q1', 'q2']
        df = pd.DataFrame(index=list(sig_p.keys()), data={'pname': pway_names},
                          columns=columns)
        for path_group, evar, qvar, ind, sig_pos in \
                [(all_paths1, 'e1', 'q1', 'ind1', sig_pos1),
                 (all_paths2, 'e2', 'q2', 'ind2', sig_pos2)]:
            for path in path_group:
                path_id = path.path_id
                if path_id not in sig_p:
                    continue
                df.loc[path_id, evar] = get_effect(path)
                df.loc[path_id, qvar] = get_q(path)
                # -1 marks a pathway that is not in this project's own list.
                df.loc[path_id, ind] = sig_pos.get(path_id, -1)
        df.ind1.fillna(-1, inplace=True)
        df.ind2.fillna(-1, inplace=True)
        # Effect sizes of pathways that have a value in one project only;
        # these are drawn on the marginal plots.
        df.e1_only = df.where(df.e2.isnull())['e1']
        df.e2_only = df.where(df.e1.isnull())['e2']

        inds1 = list(df.ind1)
        inds2 = list(df.ind2)

        source = ColumnDataSource(data=df)
        source_full = ColumnDataSource(data=df)
        source.name, source_full.name = 'data_visible', 'data_full'
        # SET UP FIGURE
        # Pad the ranges by 20%.  `m - abs(m) * 0.2` equals the former
        # `m * (1 - m / abs(m) * 0.2)` but does not divide by zero when the
        # minimum is exactly 0.
        minx = df.e1.min()
        minx -= abs(minx) * 0.2
        miny = df.e2.min()
        miny -= abs(miny) * 0.2
        maxx = df.e1.max() * 1.2
        maxy = df.e2.max() * 1.2
        TOOLS = "lasso_select,box_select,hover,crosshair,pan,wheel_zoom,"\
                "box_zoom,reset,tap,help" # poly_select,lasso_select, previewsave

        # SUPLOTS
        # p: main scatter; pa / pb: marginal strips sharing p's x / y range;
        # pp is created but not placed in the grid below.
        p = figure(plot_width=DIM_COMP_W, plot_height=DIM_COMP_H, tools=TOOLS,
                   title=None, logo=None, toolbar_location="above",
                   x_range=Range1d(minx, maxx), y_range=Range1d(miny, maxy),
                   x_axis_type="log", y_axis_type="log"
                   )
        pb = figure(plot_width=DIM_COMP_SM, plot_height=DIM_COMP_H, tools=TOOLS,
                    y_range=p.y_range, x_axis_type="log", y_axis_type="log")
        pa = figure(plot_width=DIM_COMP_W, plot_height=DIM_COMP_SM, tools=TOOLS,
                    x_range=p.x_range, x_axis_type="log", y_axis_type="log")
        pp = figure(plot_width=DIM_COMP_SM, plot_height=DIM_COMP_SM,
                    tools=TOOLS, outline_line_color=None)

        # SPANS
        p.add_layout(plot_fns.get_span(1, 'height'))
        p.add_layout(plot_fns.get_span(1, 'width'))
        pa.add_layout(plot_fns.get_span(1, 'height'))
        pb.add_layout(plot_fns.get_span(1, 'width'))

        # STYLE
        for ax in [p, pa, pb]:
            ax.grid.visible = False
            ax.outline_line_width = 2
            ax.background_fill_color = 'whitesmoke'
        for ax in [pa, pb]:
            ax.xaxis.visible = False
            ax.yaxis.visible = False

        pa.title.text = xlabel
        pb.title.text = ylabel
        pa.title_location, pa.title.align = 'below', 'center'
        pb.title_location, pb.title.align = 'left', 'center'

        # WIDGETS
        q_input = TextInput(value='', title="P* cutoff",
                            placeholder='e.g. 0.05')
        gene_input = TextInput(value='', title="Gene list",
                               placeholder='e.g. TP53,BRAF')
        radio_include = RadioGroup(labels=["Include", "Exclude"], active=0)
        widgets = widgetbox(q_input, gene_input, radio_include, width=200,
                            css_classes=['widgets_sg'])

        grid = gridplot([[pb, p, widgets],
                         [Spacer(width=DIM_COMP_SM), pa, Spacer()]],
                        sizing_mode='fixed')

        # Both callbacks call the page-level JS function
        # selectPathwaysByGenes with the gene list and include/exclude flag.
        cb_inclusion = CustomJS(args=dict(genes=gene_input), code="""
            var gene_str = genes.value
            if (!gene_str)
                return;
            var include = cb_obj.active == 0 ? true : false
            selectPathwaysByGenes(gene_str, include);
            """)
        cb_genes = CustomJS(args=dict(radio=radio_include), code="""
            var gene_str = cb_obj.value
            if (!gene_str)
                return;
            var include = radio.active == 0 ? true : false
            selectPathwaysByGenes(gene_str, include);
            """)
        radio_include.js_on_change('active', cb_inclusion)
        gene_input.js_on_change('value', cb_genes)

        # SCATTER
        p.circle("e1", "e2", source=source, **SCATTER_KW)
        pa.circle('e1_only', 1, source=source, **SCATTER_KW)
        pb.circle(1, 'e2_only', source=source, **SCATTER_KW)

        # HOVER
        for hover in grid.select(dict(type=HoverTool)):
            hover.tooltips = OrderedDict([
                ("name", "@pname"),
                ("effects", "(@e1, @e2)"),
                ("P*", ("(@q1, @q2)"))
            ])

        # ADD Q FILTERING CALLBACK
        # The Python list `columns` is interpolated into the script as a JS
        # array literal via %s.  The script relies on page-level globals
        # (scatter_array, jQuery's $, updateIfSelectionChange_afterWait).
        callback = CustomJS(args=dict(source=source, full=source_full), code="""
            // get old selection indices, if any
            var prv_selected = source.selected['1d'].indices;
            var prv_select_full = []
            for(var i=0; i<prv_selected.length; i++){
                prv_select_full.push(scatter_array[prv_selected[i]])
            }
            var new_selected = []
            var q_val = cb_obj.value;
            if(q_val == '')
                q_val = 1
            var fullset = full.data;
            var n_total = fullset['e1'].length;
            // Convert float64arrays to array
            var col_names = %s ;
            col_names.forEach(function(col_name){
                source.data[col_name] = [].slice.call(source.data[col_name])
                source.data[col_name].length = 0
            })
            scatter_array.length = 0;
            var j = -1;  // new glyph indices
            for (i = 0; i < n_total; i++) {
                this_q1 = fullset['q1'][i];
                this_q2 = fullset['q2'][i];
                if(this_q1 <= q_val || this_q2 <= q_val){
                    j++; // preserve previous selection if still visible
                    col_names.forEach(function(col){
                        source.data[col].push(fullset[col][i]);
                    })
                    scatter_array.push(i)
                    if($.inArray(i, prv_select_full) > -1){
                        new_selected.push(j);
                    }
                }
            }
            source.selected['1d'].indices = new_selected;
            source.trigger('change');
            updateIfSelectionChange_afterWait();
            """ % columns)

        q_input.js_on_change('value', callback)
        script, div = plot_fns.get_bokeh_components(grid)

        # CNV output counts as available only if both project folders
        # contain a 'matrix_svg_cnv' entry.
        proj_dir_a = naming_rules.get_project_folder(current_proj_a)
        proj_dir_b = naming_rules.get_project_folder(current_proj_b)
        has_cnv = (
            os.path.exists(os.path.join(proj_dir_a, 'matrix_svg_cnv')) and
            os.path.exists(os.path.join(proj_dir_b, 'matrix_svg_cnv')))

    else:  # not enough projects yet!
        flash("Two completed projects are required for a comparison.",
              "warning")
        return redirect(url_for('.index'))

    return render_template('pway/compare.html',
                           current_projs=[current_proj_a, current_proj_b],
                           inds_use=[inds1, inds2],
                           has_cnv=has_cnv,
                           js_name_a=js_name1,
                           js_name_b=js_name2,
                           projects=upload_list,
                           bokeh_script=script,
                           bokeh_div=div, include_genes=include,
                           resources=plot_fns.resources)
Exemplo n.º 5
0
def volcano(data,
            folder='',
            tohighlight=None,
            tooltips=None,
            title="volcano plot",
            xlabel='log-fold change',
            ylabel='-log(Q)',
            maxvalue=100,
            searchbox=False,
            logfoldtohighlight=0.15,
            pvaltohighlight=0.1,
            showlabels=False):
    """
    Make an interactive volcano plot from Differential Expression analysis tools outputs

    Args:
    -----
      data: a df with rows genes and cols [log2FoldChange, pvalue, gene_id]
      folder: str prefix of the location where to save the plot (used as-is,
        so include a trailing separator); won't save if empty
      tohighlight: list[str] of genes to highlight in the plot
      tooltips: list[tuples(str,str)] if the user wants to specify another
        bokeh tooltip; defaults to [('gene', '@gene_id')]
      title: str plot title
      xlabel: str if the user wants to specify the title of the x axis
      ylabel: str if the user wants to specify the title of the y axis
      maxvalue: float the max -log2(pvalue) authorized (useful when managing inf vals)
      searchbox: bool whether or not to add a search box to interactively highlight genes
      logfoldtohighlight: float min logfoldchange when to display points
      pvaltohighlight: float min pvalue when to display points
      showlabels: bool whether or not to show a text above each datapoint with its label information

    Returns:
    --------
      The bokeh object (the figure, or a column holding the search box and
      the figure when `searchbox` is True)
    """
    # None instead of a mutable list literal as the default argument.
    if tooltips is None:
        tooltips = [('gene', '@gene_id')]

    to_plot_not, to_plot_yes = selector(
        data, tohighlight if tohighlight is not None else [],
        logfoldtohighlight, pvaltohighlight)
    hover = bokeh.models.HoverTool(tooltips=tooltips, names=['circles'])

    # Create figure
    p = bokeh.plotting.figure(title=title, plot_width=650, plot_height=450)

    p.xgrid.grid_line_color = 'white'
    p.ygrid.grid_line_color = 'white'
    p.xaxis.axis_label = xlabel
    p.yaxis.axis_label = ylabel

    # Add the hover tool
    p.add_tools(hover)
    # First the non-highlighted points, then the highlighted ones on top.
    p, source1 = add_points(p,
                            to_plot_not,
                            'log2FoldChange',
                            'pvalue',
                            color='#1a9641',
                            maxvalue=maxvalue)
    p, source2 = add_points(p,
                            to_plot_yes,
                            'log2FoldChange',
                            'pvalue',
                            color='#fc8d59',
                            alpha=0.6,
                            outline=True,
                            maxvalue=maxvalue)
    if showlabels:
        labels = LabelSet(x='log2FoldChange',
                          y='transformed_q',
                          text_font_size='7pt',
                          text="gene_id",
                          level="glyph",
                          x_offset=5,
                          y_offset=5,
                          source=source2,
                          render_mode='canvas')
        p.add_layout(labels)

    # output_backend is a property of the figure, so it has to be set before
    # `p` is possibly rebound to a column layout below (setting it on the
    # layout, as before, fails when searchbox=True).
    p.output_backend = "svg"

    if searchbox:
        text = TextInput(title="text", value="gene")
        # The loop index is declared with `let`; an undeclared `i` is an
        # implicit global and throws in strict-mode JavaScript.
        text.js_on_change(
            'value',
            CustomJS(args=dict(source=source1),
                     code="""
      var data = source.data
      var value = cb_obj.value
      var gene_id = data.gene_id
      var a = -1
      for (let i = 0; i < gene_id.length; i++) {
          if ( gene_id[i]===value ) { a=i; console.log(i); data.size[i]=7; data.alpha[i]=1; data.color[i]='#fc8d59' }
      }
      source.data = data
      console.log(source)
      console.log(cb_obj)
      source.change.emit()
      console.log(source)
      """))
        p = column(text, p)
    if folder:
        save(p, folder + title.replace(' ', "_") + "_volcano.html")
        export_svg(p,
                   filename=folder + title.replace(' ', "_") + "_volcano.svg")
    try:
        show(p)
    except Exception:
        # Retry once, as the original did; no longer a bare `except`, so
        # KeyboardInterrupt / SystemExit are not intercepted.
        show(p)
    return p
Exemplo n.º 6
0
def spectroscopy_plot(obj_id, user, spec_id=None, width=600, height=300):
    """Build the interactive spectroscopy plot for one object.

    The plot shows every spectrum of ``obj_id`` that belongs to one of the
    groups in ``user.accessible_groups`` as a step line (flux divided by the
    spectrum's median flux), together with:

    * a checkbox legend that toggles individual spectra,
    * redshift and expansion-velocity sliders with linked text inputs,
    * checkbox groups that show/hide the reference lines in ``SPEC_LINES``,
      shifted according to the current redshift / velocity inputs.

    Parameters
    ----------
    obj_id
        Identifier passed to ``Obj.query.get`` and used to filter spectra.
    user
        Object exposing ``accessible_groups`` (each with an ``id``).
    spec_id
        Optional; when given, only the spectrum whose ``id`` equals
        ``int(spec_id)`` is plotted.
    width
        Layout width; the checkbox widgets are sized as fractions of it.
    height
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    The result of ``bokeh_embed.json_item(layout)``, or the tuple
    ``(None, None, None)`` when no spectrum matches.
    NOTE(review): the two return shapes differ; callers must handle both.
    """
    obj = Obj.query.get(obj_id)
    spectra = (
        DBSession().query(Spectrum).join(Obj).join(GroupSpectrum).filter(
            Spectrum.obj_id == obj_id,
            GroupSpectrum.group_id.in_([g.id for g in user.accessible_groups]),
        )).all()

    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if not spectra:
        return None, None, None

    # The spectra are already loaded, so index them by id instead of issuing
    # one extra query per spectrum when the legend labels are built below.
    spectra_by_id = {s.id: s for s in spectra}

    rainbow = cm.get_cmap('rainbow', len(spectra))
    palette = list(map(rgb2hex, rainbow(range(len(spectra)))))
    color_map = dict(zip(spectra_by_id, palette))

    frames = []
    for s in spectra:
        # normalize spectra to a median flux of 1 for easy comparison
        normfac = np.nanmedian(s.fluxes)

        if s.assignment is not None:
            pi = s.assignment.run.pi
        elif s.followup_request is not None:
            pi = s.followup_request.allocation.pi
        else:
            pi = ""

        frames.append(
            pd.DataFrame({
                'wavelength': s.wavelengths,
                'flux': s.fluxes / normfac,
                'id': s.id,
                'telescope': s.instrument.telescope.name,
                'instrument': s.instrument.name,
                'date_observed': s.observed_at.date().isoformat(),
                'pi': pi,
            }))
    data = pd.concat(frames)

    # Smooth each (un-normalized) spectrum with a 2-sample rolling average;
    # the result is only used to derive the initial y-range below.
    smoothed_data = pd.concat([
        pd.DataFrame({
            'wavelength': s.wavelengths,
            'flux': s.fluxes
        }).rolling(2).mean(numeric_only=True).dropna() for s in spectra
    ])

    split = data.groupby('id')
    group_keys = [key for key, _ in split]

    hover = HoverTool(tooltips=[
        ('wavelength', '$x'),
        ('flux', '$y'),
        ('telescope', '@telescope'),  # label was misspelled 'telesecope'
        ('instrument', '@instrument'),
        ('UTC date observed', '@date_observed'),
        ('PI', '@pi'),
    ])
    smoothed_max = np.max(smoothed_data['flux'])
    smoothed_min = np.min(smoothed_data['flux'])
    ymax = smoothed_max * 1.05
    ymin = smoothed_min - 0.05 * (smoothed_max - smoothed_min)
    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    plot = figure(
        aspect_ratio=2,
        sizing_mode='scale_width',
        y_range=(ymin, ymax),
        x_range=(xmin, xmax),
        tools='box_zoom,wheel_zoom,pan,reset',
        active_drag='box_zoom',
    )
    plot.add_tools(hover)

    # Renderers are exposed to the JS callbacks under the names s0, s1, ...
    # (spectra) and el0, el1, ... (reference lines).
    model_dict = {}
    for i, (key, df) in enumerate(split):
        model_dict['s' + str(i)] = plot.step(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
            mode="center",
        )
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None

    # TODO how to choose a good default?
    # NOTE: this replaces the (ymin, ymax) range passed to figure() above.
    plot.y_range = Range1d(0, 1.03 * data.flux.max())

    spec_labels = []
    for key in group_keys:
        s = spectra_by_id[key]
        spec_labels.append(
            f'{s.instrument.telescope.nickname}/{s.instrument.name} ({s.observed_at.date().isoformat()})'
        )

    toggle = CheckboxWithLegendGroup(
        labels=spec_labels,
        active=list(range(len(spectra))),
        colors=[color_map[key] for key in group_keys],
        width=width // 5,
    )
    toggle.js_on_click(
        CustomJS(
            args={
                'toggle': toggle,
                **model_dict
            },
            code="""
          for (let i = 0; i < toggle.labels.length; i++) {
              eval("s" + i).visible = (toggle.active.includes(i))
          }
    """,
        ), )

    # Redshift used for the initial widget values and line positions.
    initial_z = obj.redshift if obj.redshift is not None else 0.0

    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=initial_z,
        start=0.0,
        end=1.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(initial_z))
    z_slider.js_on_change(
        'value',
        CustomJS(
            args={
                'slider': z_slider,
                'textinput': z_textinput
            },
            code="""
            textinput.value = parseFloat(slider.value).toFixed(3);
            textinput.change.emit();
        """,
        ),
    )
    z = column(z_title, z_slider, z_textinput)

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.js_on_change(
        'value',
        CustomJS(
            args={
                'slider': v_exp_slider,
                'textinput': v_exp_textinput
            },
            code="""
            textinput.value = parseFloat(slider.value).toFixed(0);
            textinput.change.emit();
        """,
        ),
    )
    v_exp = column(v_title, v_exp_slider, v_exp_textinput)

    # One hidden vertical segment set per SPEC_LINES entry, initially placed
    # at the object's redshift.
    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + initial_z)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            y0=0,
            y1=1e4,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral line legend into columns: entry k of SPEC_LINES goes to
    # checkbox group (k % columns), padded with None where a group is short.
    columns = 7
    element_dicts = zip(*itertools.zip_longest(*[iter(SPEC_LINES.items())] *
                                               columns))

    elements_groups = []  # The Bokeh checkbox groups
    callbacks = []  # The checkbox callbacks for each element
    for column_idx, element_dict in enumerate(element_dicts):
        element_dict = [e for e in element_dict if e is not None]
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(labels=labels,
                                           active=[],
                                           colors=colors,
                                           width=width // (columns + 1))
        elements_groups.append(elements)

        # The JS loop visits el<column_idx>, el<column_idx + columns>, ...
        # i.e. exactly the renderers that belong to this checkbox group.
        callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {column_idx} +  {columns} * elements.labels.length;
            let local_i = 0;
            for (let i = {column_idx}; i < i_max; i = i + {columns}) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )
        elements.js_on_click(callback)
        callbacks.append(callback)

    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': z_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(z.value).toFixed(3);
        """,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': v_exp_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(v_exp.value).toFixed(3);
        """,
        ),
    )

    # Update the element spectral lines as well
    for callback in callbacks:
        z_textinput.js_on_change('value', callback)
        v_exp_textinput.js_on_change('value', callback)

    row1 = row(plot, toggle)
    row2 = row(elements_groups)
    row3 = row(z, v_exp)
    layout = column(row1, row2, row3, width=width)
    return bokeh_embed.json_item(layout)
Exemplo n.º 7
0
def bokeh_plot(import_df):
    """Build a tabbed Bokeh dashboard of per-location inventory metrics.

    Nine figures (backroom on-hand, stockouts, inbound/outbound, inventory,
    BOH vs POG, BOH vs ideal, fill rate, POG fit, proportion of on-hand in
    backroom) plus a per-location summary table are placed in tabs next to
    ``Location`` / ``Region`` selectors. Changing a selector filters the
    plotted data in the browser.

    Parameters
    ----------
    import_df : pandas.DataFrame
        Must provide the columns read below, including
        ``location_reference_id``, ``region``, ``date``, ``Inbound``,
        ``Outbound``, ``BOH``, ``EOH``, ``OTL``, ``DFE_Q98``, ``BOH_OOS``,
        ``EOH_OOS``, ``Backroom_OH`` and the spaced metric columns listed in
        ``tooltip_aliases``. NOTE: the frame is modified in place (all-NaN
        rows dropped, types converted, alias columns added).

    Returns
    -------
    The Bokeh layout object produced by ``bokeh.layouts.layout``.
    """
    import pandas as pd
    from bokeh.core.properties import value
    from bokeh.layouts import column, layout, widgetbox
    from bokeh.models import (BoxZoomTool, ColumnDataSource, CustomJS,
                              HoverTool, LassoSelectTool,
                              NumeralTickFormatter, PanTool, ResetTool,
                              SaveTool, WheelZoomTool)
    from bokeh.models.widgets import (DataTable, Div, NumberFormatter, Panel,
                                      Select, TableColumn, Tabs, TextInput)
    from bokeh.plotting import figure
    from bokeh.transform import dodge

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # The hover tooltips below refer to these identifier-style aliases
    # (e.g. '@BOH_gt_Ideal'), so each spaced column is mirrored under one.
    tooltip_aliases = {
        'BOH_gt_Shelf_Capacity': 'BOH > Shelf Capacity',
        'OTL_gt_Shelf_Capacity': 'OTL > Shelf Capacity',
        'Ideal_BOH_gt_Shelf_Capacity': 'Ideal BOH > Shelf Capacity',
        'BOH_lt_Ideal': 'BOH < Ideal',
        'BOH_eq_Ideal': 'BOH = Ideal',
        'BOH_gt_Ideal': 'BOH > Ideal',
        'Demand_Fulfilled': 'Demand Fulfilled',
        'Fill_Rate': 'Fill Rate',
        'Total_OH': 'Total OH',
        'Prop_OH_in_Backroom': 'Prop OH in Backroom',
        'Never_Q98_gt_POG': 'Never: Q98 > POG',
        'Never_Ideal_BOH_gt_POG': 'Never: Ideal BOH > POG',
        'Sometimes_OTL_Casepack_1_gt_POG': 'Sometimes: OTL+Casepack-1 > POG',
        'Always_OTL_Casepack_1_le_POG': 'Always: OTL+Casepack-1 <= POG',
        'Non_POG': 'Non-POG',
    }

    df = import_df
    # The original call discarded dropna()'s result, so empty rows were never
    # actually removed; inplace keeps the existing mutate-the-input behavior.
    df.dropna(how='all', axis=0, inplace=True)
    df.location_reference_id = df.location_reference_id.astype(str)
    df['date'] = pd.to_datetime(df['date'])
    for alias, original in tooltip_aliases.items():
        df[alias] = df[original]
    # String form of the date, used as the categorical x value of bar charts.
    df['date_bar'] = df['date'].astype(str)

    # Per-location totals feed the summary table.
    df_agg = df.groupby(['location_reference_id'], as_index=False).sum()
    summary_source = ColumnDataSource(data=df_agg)

    sdate = df['date'].min()
    edate = df['date'].max()
    nodes = len(df['location_reference_id'].unique())
    days = len(df['date'].unique())
    policy = "Prod"

    x_range_list = list(df['date_bar'].unique())
    all_locations = list(df['location_reference_id'].unique())
    all_regions = list(df['region'].unique())

    # ------------------------------------------------------------------
    # Widgets and data sources
    # ------------------------------------------------------------------
    desc = Div(text="All locations", width=230)
    pre = Div(
        text=" <b>Start date:</b>  <i>{}</i> <br />  <b>End date:</b> <i>{}</i> <br /> <b>Time period:</b> <i>{}</i> days <br />  <b> Total Number of Nodes:</b> <i>{}</i> <br /> <b>Policy</b> = <i>{}</i><br /> "
        .format(sdate, edate, days, nodes, policy),
        width=230)
    location = Select(title="Location",
                      options=all_locations,
                      value="All_Stores_NE")
    region = Select(title="Region", options=all_regions, value="NE")
    text_input = TextInput(value="default", title="Search Location:")

    # `source` is what the glyphs draw; `original_source` keeps the full data
    # so the JS filter can rebuild `source` on every selection change.
    source = ColumnDataSource(data=df)
    original_source = ColumnDataSource(data=df)

    # ------------------------------------------------------------------
    # Figure-building helpers
    # ------------------------------------------------------------------
    def tooltips(*metrics):
        """Return the common Location/Date tooltips followed by `metrics`."""
        return [("Location", "@location_reference_id"),
                ("Date", "@date_bar")] + list(metrics)

    def new_figure(title, y_label, tips, categorical, width=1000,
                   shaded=True):
        """Create a figure with the shared toolbar, labels and title color."""
        if categorical:
            axis_kwargs = {'x_range': x_range_list}
        else:
            axis_kwargs = {'x_axis_type': "datetime"}
        fig = figure(plot_height=525,
                     plot_width=width,
                     title=title,
                     tools=[
                         HoverTool(tooltips=tips),
                         BoxZoomTool(),
                         LassoSelectTool(),
                         WheelZoomTool(),
                         PanTool(),
                         ResetTool(),
                         SaveTool()
                     ],
                     toolbar_location='above',
                     x_axis_label="Date",
                     y_axis_label=y_label,
                     **axis_kwargs)
        if shaded:
            fig.background_fill_color = "#e6e9ed"
            fig.background_fill_alpha = 0.5
        if categorical:
            fig.xaxis.major_label_orientation = 1
        fig.title.text_color = "olive"
        return fig

    def legend_kwargs(label):
        """Return the glyph keyword for a fixed legend label, if any."""
        return {} if label is None else {'legend': value(label)}

    def add_bar(fig, field, offset, width, color, label=None,
                hover_alpha=0.5):
        """Add one dodged bar series reading `field` from `source`."""
        fig.vbar(x=dodge('date_bar', offset, range=fig.x_range),
                 top=field,
                 hover_alpha=hover_alpha,
                 hover_line_color='black',
                 width=width,
                 source=source,
                 color=color,
                 **legend_kwargs(label))

    def add_marked_line(fig, field, line_color, label=None):
        """Add hollow circle markers plus a connecting line for `field`."""
        fig.circle(x='date',
                   y=field,
                   source=source,
                   fill_color=None,
                   line_color="#4375c6")
        fig.line(x='date',
                 y=field,
                 source=source,
                 hover_alpha=0.5,
                 hover_line_color='black',
                 line_width=2,
                 line_color=line_color,
                 **legend_kwargs(label))

    def add_stack(fig, fields, colors):
        """Add a stacked bar series, one layer and legend entry per field."""
        fig.vbar_stack(fields,
                       x='date_bar',
                       width=0.9,
                       color=colors,
                       source=source,
                       legend=[value(x) for x in fields],
                       name=fields)

    def style_legend(fig, **extra):
        """Apply the shared legend border style plus any `extra` attributes.

        Only called for figures that have legend entries; on the others the
        original assignments had nothing to act on.
        """
        fig.legend.border_line_width = 3
        fig.legend.border_line_color = None
        fig.legend.border_line_alpha = 0.5
        for attr, setting in extra.items():
            setattr(fig.legend, attr, setting)

    # ------------------------------------------------------------------
    # Figures
    # ------------------------------------------------------------------
    # Backroom on hand
    backroom_fig = new_figure(
        "Backroom On hand by store", "Backroom OH",
        tooltips(("Backroom_OH", "@Backroom_OH{0,0.00}")),
        categorical=True)
    add_bar(backroom_fig, 'Backroom_OH', -0.25, 0.8, "#718dbf")

    # Inbound / outbound
    flow_fig = new_figure(
        "Inbound/Outbound by store", "Units",
        tooltips(("Inbound", "@Inbound{0,0.00}"),
                 ("Outbound", "@Outbound{0,0.00}")),
        categorical=True)
    add_bar(flow_fig, 'Inbound', -0.25, 0.4, "#718dbf", label="Inbound")
    add_bar(flow_fig, 'Outbound', 0.25, 0.4, "#e84d60", label="Outbound")
    style_legend(flow_fig)

    # Stockouts
    stockout_fig = new_figure(
        "Stockouts by store", "Prop Stockout",
        tooltips(("BOH_OOS", "@BOH_OOS{0,0.000}"),
                 ("EOH_OOS", "@EOH_OOS{0,0.000}")),
        categorical=False)
    add_marked_line(stockout_fig, 'EOH_OOS', 'navy', label="EOH OOS")
    add_marked_line(stockout_fig, 'BOH_OOS', 'red', label="BOH OOS")
    style_legend(stockout_fig)

    # Fill rate
    fill_fig = new_figure(
        "Fill rates by store", "Fill rate",
        tooltips(("Fill Rate", "@Fill_Rate{0,0.00}")),
        categorical=True)
    add_bar(fill_fig, 'Fill Rate', -0.25, 0.8, "#718dbf")

    # % Backroom spillover
    spill_fig = new_figure(
        "Prop OH in Backroom by store", " % Backroom spillover",
        tooltips(("Prop OH in Backroom", "@Prop_OH_in_Backroom{0,0.00}")),
        categorical=False)
    add_marked_line(spill_fig, 'Prop OH in Backroom', 'navy')
    spill_fig.title.text_font_style = "bold"
    spill_fig.yaxis[0].formatter = NumeralTickFormatter(format="0.0%")

    # BOH vs Ideal
    ideal_fields = ['BOH < Ideal', 'BOH > Ideal', 'BOH = Ideal']
    ideal_fig = new_figure(
        "BOH vs Ideal by store", "Prop",
        tooltips(('BOH < Ideal', "@BOH_lt_Ideal{0,0.00}"),
                 ('BOH > Ideal', "@BOH_gt_Ideal{0,0.00}"),
                 ('BOH = Ideal', "@BOH_eq_Ideal{0,0.00}")),
        categorical=True,
        shaded=False)
    add_stack(ideal_fig, ideal_fields, ["#c9d9d3", "#718dbf", "#e84d60"])
    style_legend(ideal_fig)

    # Pog fit
    pog_fields = [
        'Never: Q98 > POG', 'Never: Ideal BOH > POG',
        'Sometimes: OTL+Casepack-1 > POG', 'Always: OTL+Casepack-1 <= POG',
        'Non-POG'
    ]
    pog_fig = new_figure(
        "Pog Fit by store", "Counts",
        tooltips(('Never: Q98 > POG', "@Never_Q98_gt_POG{0,0.00}"),
                 ("Never: Ideal BOH > POG",
                  "@Never_Ideal_BOH_gt_POG{0,0.00}"),
                 ("Sometimes: OTL+Casepack-1 > POG",
                  "@Sometimes_OTL_Casepack_1_gt_POG{0,0.00}"),
                 ("Always: OTL+Casepack-1 <= POG",
                  "@Always_OTL_Casepack_1_le_POG{0,0.00}"),
                 # label previously carried a stray apostrophe ("Non-POG'")
                 ("Non-POG", "@Non_POG{0,0.00}")),
        categorical=True,
        width=1200,
        shaded=False)
    add_stack(pog_fig, pog_fields,
              ['#79D151', "#718dbf", '#29788E', '#fc8d59', '#d53e4f'])
    style_legend(pog_fig, location="top_right")

    # BOH vs Pog
    shelf_fig = new_figure(
        "BOH vs Pog by store", "Prop",
        tooltips(("OTL > Shelf Capacity", "@OTL_gt_Shelf_Capacity{0,0.00}"),
                 ("BOH > Shelf Capacity", "@BOH_gt_Shelf_Capacity{0,0.00}"),
                 ("Ideal BOH > Shelf Capacity",
                  "@Ideal_BOH_gt_Shelf_Capacity{0,0.00}")),
        categorical=False)
    add_marked_line(shelf_fig, 'BOH > Shelf Capacity', 'navy',
                    label="BOH > Shelf Capacity")
    add_marked_line(shelf_fig, 'OTL > Shelf Capacity', "green",
                    label="OTL > Shelf Capacity")
    add_marked_line(shelf_fig, 'Ideal BOH > Shelf Capacity', "#e84d60",
                    label="Ideal BOH > Shelf Capacity")
    style_legend(shelf_fig, click_policy="mute")

    # Inventory (title was a copy-paste of the Inbound/Outbound figure's)
    inventory_fig = new_figure(
        "Inventory by store", "Units",
        tooltips(("DFE_Q98", "@DFE_Q98{0,0.00}"), ("OTL", "@OTL{0,0.00}"),
                 ("EOH", "@EOH{0,0.00}"), ("BOH", "@BOH{0,0.00}")),
        categorical=True,
        width=1200)
    for field, offset, color in [('DFE_Q98', -0.40, "#FBA40A"),
                                 ('OTL', -0.20, "#4292c6"),
                                 ('EOH', 0.00, '#a1dab4'),
                                 ('BOH', 0.20, "#DC5039")]:
        add_bar(inventory_fig, field, offset, 0.2, color, label=field,
                hover_alpha=0.3)
    style_legend(inventory_fig, location="top_left", click_policy="mute")

    # ------------------------------------------------------------------
    # Client-side filtering
    # ------------------------------------------------------------------
    # Rebuilds every column of `source` from `original_source`, keeping only
    # rows matching the selected location AND region. Uses the property /
    # signal API (`.data`, `.value`, `change.emit()`); the former
    # `.get(...)` / `.trigger('change')` calls are not available in BokehJS
    # releases that provide vbar_stack/dodge.
    callback = CustomJS(args=dict(source=source,
                                  original_source=original_source,
                                  location_select_obj=location,
                                  region_select_obj=region,
                                  div=desc,
                                  text_input=text_input),
                        code="""
    var data = source.data;
    var original_data = original_source.data;
    var loc = location_select_obj.value;
    var reg = region_select_obj.value;
    var line = " <br />  <b> Region:</b>"+ reg + "<br />  <b>Location:</b> " +   loc;
    div.text=line;
    for (var key in original_data) {
        data[key] = [];
        for (var i = 0; i < original_data['location_reference_id'].length; ++i) {
            if ((original_data['location_reference_id'][i] === loc) && (original_data['region'][i] === reg) ) {
                data[key].push(original_data[key][i]);
            }
        }
    }
    source.change.emit();
    """)

    # NOTE(review): 'event' does not look like a real Bokeh event name, so
    # this registration presumably never fires -- confirm and remove.
    desc.js_on_event('event', callback)
    location.js_on_change('value', callback)
    region.js_on_change('value', callback)
    # NOTE(review): text_input is wired to the callback but is not part of
    # the returned layout.
    text_input.js_on_change('value', callback)
    inputs = widgetbox(location, region, desc, pre, width=220, height=500)

    # ------------------------------------------------------------------
    # Summary table and layout
    # ------------------------------------------------------------------
    numeric_fields = [
        "Backroom_OH", "Outbound", "Inbound", "OTL", "DFE_Q98", "BOH", "EOH",
        "BOH_OOS", "EOH_OOS"
    ]
    columns = [TableColumn(field="location_reference_id", title="Location")]
    columns += [
        TableColumn(field=name,
                    title=name,
                    formatter=NumberFormatter(format="0,0"))
        for name in numeric_fields
    ]
    data_table = DataTable(source=summary_source, columns=columns, width=1250)

    # Tab order as displayed.
    view = Tabs(tabs=[
        Panel(child=backroom_fig, title='Backroom OH'),
        Panel(child=stockout_fig, title='Stockouts'),
        Panel(child=flow_fig, title='Inbound/Outbound'),
        Panel(child=inventory_fig, title='Inventory'),
        Panel(child=shelf_fig, title='BOH vs POG'),
        Panel(child=ideal_fig, title='BOH vs Ideal'),
        Panel(child=fill_fig, title='Fill Rate'),
        Panel(child=pog_fig, title='Pog Fit'),
        Panel(child=spill_fig, title='Prop OH in Backroom'),
        Panel(child=data_table, title='Summary Table'),
    ])

    layout_text = column(inputs)
    return layout(children=[[layout_text, view]], sizing_mode='scale_height')
Exemplo n.º 8
0
def main():
    print('''Please select the CSV dataset you\'d like to use.
    The dataset should contain these columns:
        - metric to apply threshold to
        - indicator of event to detect (e.g. malicious activity)
            - Please label this as 1 or 0 (true or false); 
            This will not work otherwise!
    ''')
    # Import the dataset
    imported_data = None
    while isinstance(imported_data, pd.DataFrame) == False:
        file_path = input('Enter the path of your dataset: ')
        imported_data = file_to_df(file_path)

    time.sleep(1)

    print(f'''\nGreat! Here is a preview of your data:
Imported fields:''')
    # List headers by column index.
    cols = list(imported_data.columns)
    for index in range(len(cols)):
        print(f'{index}: {cols[index]}')
    print(f'Number of records: {len(imported_data.index)}\n')
    # Preview the DataFrame
    time.sleep(1)
    print(imported_data.head(), '\n')

    # Prompt for the metric and source of truth.
    time.sleep(1)
    metric_col, indicator_col = columns_picker(cols)
    # User self-validation.
    col_check = input('Can you confirm if this is correct? (y/n): ').lower()
    # If it's wrong, let them try again
    while col_check != 'y':
        metric_col, indicator_col = columns_picker(cols)
        col_check = input(
            'Can you confirm if this is correct? (y/n): ').lower()
    else:
        print(
            '''\nGreat! Thanks for your patience. Generating summary stats now..\n'''
        )

    # Generate summary stats.
    time.sleep(1)
    malicious, normal = classification_split(imported_data, metric_col,
                                             indicator_col)
    mal_mean = malicious.mean()
    mal_stddev = malicious.std()
    mal_count = malicious.size
    mal_median = malicious.median()
    norm_mean = normal.mean()
    norm_stddev = normal.std()
    norm_count = normal.size
    norm_median = normal.median()

    print(f'''Normal vs Malicious Summary (metric = {metric_col}):
Normal:
-----------------------------
Observations: {round(norm_count, 2)}
Average: {round(norm_mean, 2)}
Median: {round(norm_median, 2)}
Standard Deviation: {round(norm_stddev, 2)}

Malicious:
-----------------------------
Observations: {round(mal_count, 2)}
Average: {round(mal_mean, 2)}
Median: {round(mal_median, 2)}
Standard Deviation: {round(mal_stddev, 2)}
''')

    # Insights and advisories
    # Provide the accuracy metrics of a generic threshold at avg + 3 std deviations
    generic_threshold = confusion_matrix(
        malicious, normal, threshold_calc(norm_mean, norm_stddev, 3))

    time.sleep(1)
    print(
        f'''A threshold at (average + 3x standard deviations) {metric_col} would result in:
    - True Positives (correctly identified malicious events: {generic_threshold['TP']:,}
    - False Positives (wrongly identified normal events: {generic_threshold['FP']:,}
    - True Negatives (correctly identified normal events: {generic_threshold['TN']:,}
    - False Negatives (wrongly identified malicious events: {generic_threshold['FN']:,}

    Accuracy Metrics:
    - Precision (what % of events above threshold are actually malicious): {round(generic_threshold['precision'] * 100, 1)}%
    - Recall (what % of malicious events did we catch): {round(generic_threshold['recall'] * 100, 1)}%
    - F1 Score (blends precision and recall): {round(generic_threshold['f1_score'] * 100, 1)}%'''
    )

    # Distribution skew check.
    if norm_mean >= (norm_median * 1.1):
        time.sleep(1)
        print(
            f'''\nYou may want to be cautious as your normal traffic\'s {metric_col} 
has a long tail towards high values. The median is {round(norm_median, 2)} 
compared to {round(norm_mean, 2)} for the average.''')

    if mal_mean < threshold_calc(norm_mean, norm_stddev, 2):
        time.sleep(1)
        print(
            f'''\nWarning: you may find it difficult to avoid false positives as the average
{metric_col} for malicious traffic is under the 95th percentile of the normal traffic.'''
        )

    # For fun/anticipation. Actually a nerd joke because of the method we'll be using.
    if '-q' not in sys.argv[1:]:
        time.sleep(1)
        play_a_game.billy()
        decision = input('yes/no: ').lower()
        while decision != 'yes':
            time.sleep(1)
            print('...That\'s no fun...')
            decision = input('Let\'s try that again: ').lower()

    # Let's get to the simulations!
    time.sleep(1)
    print('''\nInstead of manually experimenting with threshold multipliers, 
let\'s simulate a range of options and see what produces the best result. 
This is similar to what is known as \"Monte Carlo simulation\".\n''')

    # Initialize session name & create app folder if there isn't one.
    time.sleep(1)
    session_name = input('Please provide a name for this project/session: ')
    session_folder = make_folder(session_name)

    # Generate list of multipliers to iterate over.
    time.sleep(1)
    mult_start = float(
        input(
            'Please provide the minimum multiplier you want to start at. We recommend 2: '
        ))
    # Set the max to how many std deviations away the sample max is.
    mult_end = (imported_data[metric_col].max() - norm_mean) / norm_stddev
    mult_interval = float(
        input('Please provide the desired gap between multiplier options: '))
    # range() only accepts integers, so populate the list of multipliers manually.
    multipliers = []
    mult_counter = mult_start
    while mult_counter < mult_end:
        multipliers.append(round(mult_counter, 2))
        mult_counter += mult_interval
    print('Generating simulations..\n')

    # Run simulations using our multipliers.
    simulations = monte_carlo(malicious, normal, norm_mean, norm_stddev,
                              multipliers)
    print('Done!')
    time.sleep(1)

    # Save simulations as CSV for later use.
    simulation_filepath = os.path.join(
        session_folder, f'{session_name}_simulation_results.csv')
    simulations.to_csv(simulation_filepath, index=False)
    print(f'Saved results to: {simulation_filepath}')
    # Find the first threshold with the highest F1 score.
    # This provides a balanced approach between precision and recall.
    f1_max = simulations[simulations.f1_score ==
                         simulations.f1_score.max()].head(1)
    f1_max_mult = f1_max.squeeze()['multiplier']
    time.sleep(1)
    print(
        f'''\nBased on the F1 score metric, setting a threshold at {round(f1_max_mult,1)} standard deviations
above the average magnitude might provide optimal results.\n''')
    time.sleep(1)
    print(f'''{f1_max}

We recommend that you skim the CSV and the following visualization outputs 
to sanity check results and make your own judgement.
''')

    # Now for the fun part..generating the visualizations via Bokeh.

    # Header & internal CSS.
    title_text = '''
    <style>

    @font-face {
        font-family: RobotoBlack;
        src: url(fonts/Roboto-Black.ttf);
        font-weight: bold;
    }

    
     @font-face {
        font-family: RobotoBold;
        src: url(fonts/Roboto-Bold.ttf);
        font-weight: bold;
    }   
    
    @font-face {
        font-family: RobotoRegular;
        src: url(fonts/Roboto-Regular.ttf);
    }

    body {
        background-color: #f2ebe6;
    }

    title_header {
        font-size: 80px;
        font-style: bold;
        font-family: RobotoBlack, Helvetica;
        font-weight: bold;
        margin-bottom: -200px;
    }

    h1, h2, h3 {
        font-family: RobotoBlack, Helvetica;
        color: #313596;
    }

    p {
        font-size: 12px;
        font-family: RobotoRegular
    }

    b {
        color: #58c491;
    }

    th, td {
        text-align:left;
        padding: 5px;
    }

    tr:nth-child(even) {
        background-color: white;
        opacity: .7;
    }

    .vertical { 
        border-left: 1px solid black; 
        height: 190px; 
            } 
    </style>

        <title_header style="text-align:left; color: white;">
            Cream.
        </title_header>
        <p style="font-family: RobotoBold, Helvetica;
        font-size:18px;
        margin-top: 0px;
        margin-left: 5px;">
            Because time is money, and <b style="font-size=18px;">"Cash Rules Everything Around Me"</b>.
        </p>
    </div>
    '''

    title_div = Div(text=title_text,
                    width=800,
                    height=160,
                    margin=(40, 0, 0, 70))

    # Summary stats from earlier.
    summary_text = f'''
    <h1>Results Overview</h1> 
    <i>metric = magnitude</i>

    <table style="width:100%">
      <tr>
        <th>Metric</th>
        <th>Normal Events</th>
        <th>Malicious Events</th>
      </tr>
      <tr>
        <td>Observations</td>
        <td>{norm_count:,}</td>
        <td>{mal_count:,}</td>
      </tr>
      <tr>
        <td>Average</td>
        <td>{round(norm_mean, 2):,}</td>
        <td>{round(mal_mean, 2):,}</td>
      </tr>
      <tr>
        <td>Median</td>
        <td>{round(norm_median, 2):,}</td>
        <td>{round(mal_median, 2):,}</td>
      </tr> 
      <tr>
        <td>Standard Deviation</td>
        <td>{round(norm_stddev, 2):,}</td>
        <td>{round(mal_stddev, 2):,}</td>
      </tr> 
    </table>
    '''

    summary_div = Div(text=summary_text,
                      width=470,
                      height=320,
                      margin=(3, 0, -70, 73))

    # Results of the hypothetical threshold.
    hypothetical = f'''
    <h1>"Rule of thumb" Hypothetical Threshold</h1>
    <p>A threshold at <i>(average + 3x standard deviations)</i> {metric_col} would result in:</p>
    <ul>
        <li>True Positives (correctly identified malicious events: 
            <b>{generic_threshold['TP']:,}</b></li>
        <li>False Positives (wrongly identified normal events:
            <b>{generic_threshold['FP']:,}</b></li>
        <li>True Negatives (correctly identified normal events: 
            <b>{generic_threshold['TN']:,}</b></li>
        <li>False Negatives (wrongly identified malicious events: 
            <b>{generic_threshold['FN']:,}</b></li>
    </ul>
    <h2>Accuracy Metrics</h2>
    <ul>
        <li>Precision (what % of events above threshold are actually malicious): 
            <b>{round(generic_threshold['precision'] * 100, 1)}%</b></li>
        <li>Recall (what % of malicious events did we catch): 
            <b>{round(generic_threshold['recall'] * 100, 1)}%</b></li>
        <li>F1 Score (blends precision and recall): 
            <b>{round(generic_threshold['f1_score'] * 100, 1)}%</b></li>
    </ul>
    '''

    hypo_div = Div(text=hypothetical,
                   width=600,
                   height=320,
                   margin=(5, 0, -70, 95))

    line = '''
    <div class="vertical"></div>
    '''
    vertical_line = Div(text=line,
                        width=20,
                        height=320,
                        margin=(80, 0, -70, -10))

    # Let's get the exploratory charts generated.

    malicious_hist, malicious_edge = np.histogram(malicious, bins=100)
    mal_hist_df = pd.DataFrame({
        'metric': malicious_hist,
        'left': malicious_edge[:-1],
        'right': malicious_edge[1:]
    })

    normal_hist, normal_edge = np.histogram(normal, bins=100)
    norm_hist_df = pd.DataFrame({
        'metric': normal_hist,
        'left': normal_edge[:-1],
        'right': normal_edge[1:]
    })

    exploratory = figure(
        plot_width=plot_width,
        plot_height=plot_height,
        sizing_mode='fixed',
        title=f'{metric_col.capitalize()} Distribution (σ = std dev)',
        x_axis_label=f'{metric_col.capitalize()}',
        y_axis_label='Observations')

    exploratory.title.text_font_size = title_font_size
    exploratory.border_fill_color = cell_bg_color
    exploratory.border_fill_alpha = cell_bg_alpha
    exploratory.background_fill_color = cell_bg_color
    exploratory.background_fill_alpha = plot_bg_alpha
    exploratory.min_border_left = left_border
    exploratory.min_border_right = right_border
    exploratory.min_border_top = top_border
    exploratory.min_border_bottom = bottom_border

    exploratory.quad(bottom=0,
                     top=mal_hist_df.metric,
                     left=mal_hist_df.left,
                     right=mal_hist_df.right,
                     legend_label='malicious',
                     fill_color=malicious_color,
                     alpha=.85,
                     line_alpha=.35,
                     line_width=.5)
    exploratory.quad(bottom=0,
                     top=norm_hist_df.metric,
                     left=norm_hist_df.left,
                     right=norm_hist_df.right,
                     legend_label='normal',
                     fill_color=normal_color,
                     alpha=.35,
                     line_alpha=.35,
                     line_width=.5)

    exploratory.add_layout(
        Arrow(end=NormalHead(fill_color=malicious_color, size=10,
                             line_alpha=0),
              line_color=malicious_color,
              x_start=mal_mean,
              y_start=mal_count,
              x_end=mal_mean,
              y_end=0))
    arrow_label = Label(x=mal_mean,
                        y=mal_count,
                        y_offset=5,
                        text='Malicious Events',
                        text_font_style='bold',
                        text_color=malicious_color,
                        text_font_size='10pt')

    exploratory.add_layout(arrow_label)
    exploratory.xaxis.formatter = NumeralTickFormatter(format='0,0')
    exploratory.yaxis.formatter = NumeralTickFormatter(format='0,0')

    # 3 sigma reference line
    sigma_ref(exploratory, norm_mean, norm_stddev)

    exploratory.legend.location = "top_right"
    exploratory.legend.background_fill_alpha = .3

    # Zoomed in version
    overlap_view = figure(
        plot_width=plot_width,
        plot_height=plot_height,
        sizing_mode='fixed',
        title=f'Overlap Highlight',
        x_axis_label=f'{metric_col.capitalize()}',
        y_axis_label='Observations',
        y_range=(0, mal_count * .33),
        x_range=(norm_mean + (norm_stddev * 2.5), mal_mean + (mal_stddev * 3)),
    )

    overlap_view.title.text_font_size = title_font_size
    overlap_view.border_fill_color = cell_bg_color
    overlap_view.border_fill_alpha = cell_bg_alpha
    overlap_view.background_fill_color = cell_bg_color
    overlap_view.background_fill_alpha = plot_bg_alpha
    overlap_view.min_border_left = left_border
    overlap_view.min_border_right = right_border
    overlap_view.min_border_top = top_border
    overlap_view.min_border_bottom = bottom_border

    overlap_view.quad(bottom=0,
                      top=mal_hist_df.metric,
                      left=mal_hist_df.left,
                      right=mal_hist_df.right,
                      legend_label='malicious',
                      fill_color=malicious_color,
                      alpha=.85,
                      line_alpha=.35,
                      line_width=.5)
    overlap_view.quad(bottom=0,
                      top=norm_hist_df.metric,
                      left=norm_hist_df.left,
                      right=norm_hist_df.right,
                      legend_label='normal',
                      fill_color=normal_color,
                      alpha=.35,
                      line_alpha=.35,
                      line_width=.5)
    overlap_view.xaxis.formatter = NumeralTickFormatter(format='0,0')
    overlap_view.yaxis.formatter = NumeralTickFormatter(format='0,0')

    sigma_ref(overlap_view, norm_mean, norm_stddev)

    overlap_view.legend.location = "top_right"
    overlap_view.legend.background_fill_alpha = .3

    # Probability Density - bigger bins for sparser malicious observations
    malicious_hist_dense, malicious_edge_dense = np.histogram(malicious,
                                                              density=True,
                                                              bins=50)
    mal_hist_dense_df = pd.DataFrame({
        'metric': malicious_hist_dense,
        'left': malicious_edge_dense[:-1],
        'right': malicious_edge_dense[1:]
    })

    normal_hist_dense, normal_edge_dense = np.histogram(normal,
                                                        density=True,
                                                        bins=100)
    norm_hist_dense_df = pd.DataFrame({
        'metric': normal_hist_dense,
        'left': normal_edge_dense[:-1],
        'right': normal_edge_dense[1:]
    })

    density = figure(plot_width=plot_width,
                     plot_height=plot_height,
                     sizing_mode='fixed',
                     title='Probability Density',
                     x_axis_label=f'{metric_col.capitalize()}',
                     y_axis_label='% of Group Total')

    density.title.text_font_size = title_font_size
    density.border_fill_color = cell_bg_color
    density.border_fill_alpha = cell_bg_alpha
    density.background_fill_color = cell_bg_color
    density.background_fill_alpha = plot_bg_alpha
    density.min_border_left = left_border
    density.min_border_right = right_border
    density.min_border_top = top_border
    density.min_border_bottom = bottom_border

    density.quad(bottom=0,
                 top=mal_hist_dense_df.metric,
                 left=mal_hist_dense_df.left,
                 right=mal_hist_dense_df.right,
                 legend_label='malicious',
                 fill_color=malicious_color,
                 alpha=.85,
                 line_alpha=.35,
                 line_width=.5)
    density.quad(bottom=0,
                 top=norm_hist_dense_df.metric,
                 left=norm_hist_dense_df.left,
                 right=norm_hist_dense_df.right,
                 legend_label='normal',
                 fill_color=normal_color,
                 alpha=.35,
                 line_alpha=.35,
                 line_width=.5)
    density.xaxis.formatter = NumeralTickFormatter(format='0,0')
    density.yaxis.formatter = NumeralTickFormatter(format='0.000%')

    sigma_ref(density, norm_mean, norm_stddev)

    density.legend.location = "top_right"
    density.legend.background_fill_alpha = .3

    # Simulation Series to be used
    false_positives = simulations.FP
    false_negatives = simulations.FN
    multiplier = simulations.multiplier
    precision = simulations.precision
    recall = simulations.recall
    f1_score = simulations.f1_score
    f1_max = simulations[simulations.f1_score == simulations.f1_score.max(
    )].head(1).squeeze()['multiplier']

    # False Positives vs False Negatives
    errors = figure(plot_width=plot_width,
                    plot_height=plot_height,
                    sizing_mode='fixed',
                    x_range=(multiplier.min(), multiplier.max()),
                    y_range=(0, false_positives.max()),
                    title='False Positives vs False Negatives',
                    x_axis_label='Multiplier',
                    y_axis_label='Count')

    errors.title.text_font_size = title_font_size
    errors.border_fill_color = cell_bg_color
    errors.border_fill_alpha = cell_bg_alpha
    errors.background_fill_color = cell_bg_color
    errors.background_fill_alpha = plot_bg_alpha
    errors.min_border_left = left_border
    errors.min_border_right = right_border
    errors.min_border_top = top_border
    errors.min_border_bottom = right_border

    errors.line(multiplier,
                false_positives,
                legend_label='false positives',
                line_width=2,
                color=fp_color)
    errors.line(multiplier,
                false_negatives,
                legend_label='false negatives',
                line_width=2,
                color=fn_color)
    errors.yaxis.formatter = NumeralTickFormatter(format='0,0')

    errors.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)}
    errors.add_layout(
        LinearAxis(y_range_name="y2",
                   axis_label="Score",
                   formatter=NumeralTickFormatter(format='0.00%')), 'right')
    errors.line(multiplier,
                f1_score,
                line_width=2,
                color=f1_color,
                legend_label='F1 Score',
                y_range_name="y2")

    # F1 Score Maximization point
    f1_thresh = Span(location=f1_max,
                     dimension='height',
                     line_color=f1_color,
                     line_dash='dashed',
                     line_width=2)
    f1_label = Label(x=f1_max + .05,
                     y=180,
                     y_units='screen',
                     text=f'F1 Max: {round(f1_max,2)}',
                     text_font_size='10pt',
                     text_font_style='bold',
                     text_align='left',
                     text_color=f1_color)

    errors.add_layout(f1_thresh)
    errors.add_layout(f1_label)

    errors.legend.location = "top_right"
    errors.legend.background_fill_alpha = .3

    # False Negative Weighting.
    # Intro.
    weighting_intro = f'''
    <h3>Error types differ in impact.</h3> 
    <p>In the case of security incidents, a false negative, 
though possibly rarer than false positives, is likely more costly. For example, downtime suffered 
from a DDoS attack (lost sales/customers) incurs more loss than time wasted chasing a false positive 
(labor hours). </p>

<p>Try playing around with the slider to the right to see how your thresholding strategy might need to change 
depending on the relative weight of false negatives to false positives. What does it look like at
1:1, 50:1, etc.?</p>
'''

    weighting_div = Div(text=weighting_intro,
                        width=420,
                        height=180,
                        margin=(0, 75, 0, 0))

    # Now for the weighted errors viz

    default_weighting = 10
    initial_fp_cost = 100
    simulations['weighted_FN'] = simulations.FN * default_weighting
    weighted_fn = simulations.weighted_FN
    simulations[
        'total_weighted_error'] = simulations.FP + simulations.weighted_FN
    total_weighted_error = simulations.total_weighted_error
    simulations['fp_cost'] = initial_fp_cost
    fp_cost = simulations.fp_cost
    simulations[
        'total_estimated_cost'] = simulations.total_weighted_error * simulations.fp_cost
    total_estimated_cost = simulations.total_estimated_cost
    twe_min = simulations[simulations.total_weighted_error ==
                          simulations.total_weighted_error.min()].head(
                              1).squeeze()['multiplier']
    twe_min_count = simulations[simulations.multiplier == twe_min].head(
        1).squeeze()['total_weighted_error']
    generic_twe = simulations[simulations.multiplier.apply(
        lambda x: round(x, 2)) == 3.00].squeeze()['total_weighted_error']

    comparison = f'''
    <p>Based on your inputs, the optimal threshold is around <b>{twe_min}</b>.
    This would result in an estimated <b>{int(twe_min_count):,}</b> total weighted errors and 
    <b>${int(twe_min_count * initial_fp_cost):,}</b> in losses.</p>

    <p>The generic threshold of 3.0 standard deviations would result in <b>{int(generic_twe):,}</b> 
    total weighted errors and <b>${int(generic_twe * initial_fp_cost):,}</b> in losses.</p>

    <p>Using the optimal threshold would save <b>${int((generic_twe - twe_min_count) * initial_fp_cost):,}</b>, 
    reducing costs by <b>{(generic_twe - twe_min_count) / generic_twe * 100:.1f}%</b> 
    (assuming near-future events are distributed similarly to those from the past).</p>
    '''
    comparison_div = Div(text=comparison,
                         width=420,
                         height=230,
                         margin=(0, 75, 0, 0))

    loss_min = ColumnDataSource(data=dict(multiplier=multiplier,
                                          fp=false_positives,
                                          fn=false_negatives,
                                          weighted_fn=weighted_fn,
                                          twe=total_weighted_error,
                                          fpc=fp_cost,
                                          tec=total_estimated_cost,
                                          precision=precision,
                                          recall=recall,
                                          f1=f1_score))

    evaluation = Figure(plot_width=900,
                        plot_height=520,
                        sizing_mode='fixed',
                        x_range=(multiplier.min(), multiplier.max()),
                        title='Evaluation Metrics vs Total Estimated Cost',
                        x_axis_label='Multiplier',
                        y_axis_label='Cost')

    evaluation.title.text_font_size = title_font_size
    evaluation.border_fill_color = cell_bg_color
    evaluation.border_fill_alpha = cell_bg_alpha
    evaluation.background_fill_color = cell_bg_color
    evaluation.background_fill_alpha = plot_bg_alpha
    evaluation.min_border_left = left_border
    evaluation.min_border_right = right_border
    evaluation.min_border_top = top_border
    evaluation.min_border_bottom = bottom_border

    evaluation.line('multiplier',
                    'tec',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=total_weighted_color,
                    legend_label='Total Estimated Cost')
    evaluation.yaxis.formatter = NumeralTickFormatter(format='$0,0')

    # Evaluation metrics on second right axis.
    evaluation.extra_y_ranges = {"y2": Range1d(start=0, end=1.1)}

    evaluation.add_layout(
        LinearAxis(y_range_name="y2",
                   axis_label="Score",
                   formatter=NumeralTickFormatter(format='0.00%')), 'right')
    evaluation.line('multiplier',
                    'precision',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=precision_color,
                    legend_label='Precision',
                    y_range_name="y2")
    evaluation.line('multiplier',
                    'recall',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=recall_color,
                    legend_label='Recall',
                    y_range_name="y2")
    evaluation.line('multiplier',
                    'f1',
                    source=loss_min,
                    line_width=3,
                    line_alpha=0.6,
                    color=f1_color,
                    legend_label='F1 score',
                    y_range_name="y2")
    evaluation.legend.location = "bottom_right"
    evaluation.legend.background_fill_alpha = .3

    twe_thresh = Span(location=twe_min,
                      dimension='height',
                      line_color=total_weighted_color,
                      line_dash='dashed',
                      line_width=2)
    twe_label = Label(x=twe_min - .05,
                      y=240,
                      y_units='screen',
                      text=f'Cost Min: {round(twe_min,2)}',
                      text_font_size='10pt',
                      text_font_style='bold',
                      text_align='right',
                      text_color=total_weighted_color)
    evaluation.add_layout(twe_thresh)
    evaluation.add_layout(twe_label)

    # Add in same f1 thresh as previous viz
    evaluation.add_layout(f1_thresh)
    evaluation.add_layout(f1_label)

    handler = CustomJS(args=dict(source=loss_min,
                                 thresh=twe_thresh,
                                 label=twe_label,
                                 comparison=comparison_div),
                       code="""
       var data = source.data
       var ratio = cb_obj.value
       var multiplier = data['multiplier']
       var fp = data['fp']
       var fn = data['fn']
       var weighted_fn = data['weighted_fn']
       var twe = data['twe']
       var fpc = data['fpc']
       var tec = data['tec']
       var generic_twe = 0
       
       function round(value, decimals) {
       return Number(Math.round(value+'e'+decimals)+'e-'+decimals);
       }
       
       function comma_sep(x) {
           return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ",");
       }
       
       for (var i = 0; i < multiplier.length; i++) {
          weighted_fn[i] = Math.round(fn[i] * ratio)
          twe[i] = weighted_fn[i] + fp[i]
          tec[i] = twe[i] * fpc[i]
          if (round(multiplier[i],2) == 3.00) {
            generic_twe = twe[i]
          }
       }
              
       var min_loss = Math.min.apply(null,twe)
       var new_thresh = 0
       
       for (var i = 0; i < multiplier.length; i++) {
       if (twe[i] == min_loss) {
           new_thresh = multiplier[i]
           thresh.location = new_thresh
           thresh.change.emit()
           label.x = new_thresh
           label.text = `Cost Min: ${new_thresh}`
           label.change.emit()
           comparison.text = `
            <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>.
            This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and 
            <b>$${comma_sep(round(min_loss * fpc[i],0))}</b> in losses.</p>
        
            <p>The generic threshold of 3.0 standard deviations would result in <b>${comma_sep(round(generic_twe,0))}</b> 
            total weighted errors and <b>$${comma_sep(round(generic_twe * fpc[i],0))}</b> in losses.</p>
        
            <p>Using the optimal threshold would save <b>$${comma_sep(round((generic_twe - min_loss) * fpc[i],0))}</b>, 
            reducing costs by <b>${comma_sep(round((generic_twe - min_loss) / generic_twe * 100,0))}%</b> 
            (assuming near-future events are distributed similarly to those from the past).</p>
           `
           comparison.change.emit()
         }
       }
       source.change.emit();
    """)

    slider = Slider(start=1.0,
                    end=500,
                    value=default_weighting,
                    step=.25,
                    title="FN:FP Ratio",
                    bar_color='#FFD100',
                    height=50,
                    margin=(5, 0, 5, 0))
    slider.js_on_change('value', handler)

    cost_handler = CustomJS(args=dict(source=loss_min,
                                      comparison=comparison_div),
                            code="""
           var data = source.data
           var new_cost = cb_obj.value
           var multiplier = data['multiplier']
           var fp = data['fp']
           var fn = data['fn']
           var weighted_fn = data['weighted_fn']
           var twe = data['twe']
           var fpc = data['fpc']
           var tec = data['tec']
           var generic_twe = 0
           
           function round(value, decimals) {
           return Number(Math.round(value+'e'+decimals)+'e-'+decimals);
           } 

           function comma_sep(x) {
               return x.toString().replace(/\B(?<!\.\d*)(?=(\d{3})+(?!\d))/g, ",");
           }
           
           for (var i = 0; i < multiplier.length; i++) {
              fpc[i] = new_cost
              tec[i] = twe[i] * fpc[i]
              if (round(multiplier[i],2) == 3.00) {
                generic_twe = twe[i]
              }
           }

           var min_loss = Math.min.apply(null,twe)
           var new_thresh = 0

           for (var i = 0; i < multiplier.length; i++) {
           if (twe[i] == min_loss) {
               new_thresh = multiplier[i]
               comparison.text = `
                <p>Based on your inputs, the optimal threshold is around <b>${new_thresh}</b>.
                This would result in an estimated <b>${comma_sep(round(min_loss,0))}</b> total weighted errors and 
                <b>$${comma_sep(round(min_loss * new_cost,0))}</b> in losses.</p>

                <p>The generic threshold of 3.0 standard deviations would result in 
                <b>${comma_sep(round(generic_twe,0))}</b> total weighted errors and 
                <b>$${comma_sep(round(generic_twe * new_cost,0))}</b> in losses.</p>

                <p>Using the optimal threshold would save 
                <b>$${comma_sep(round((generic_twe - min_loss) * new_cost,0))}</b>, 
                reducing costs by <b>${comma_sep(round((generic_twe - min_loss)/generic_twe * 100,0))}%</b> 
                (assuming near-future events are distributed similarly to those from the past).</p>
               `
               comparison.change.emit()
              }
           }
           source.change.emit();
        """)

    cost_input = TextInput(value=f"{initial_fp_cost}",
                           title="How much a false positive costs:",
                           height=75,
                           margin=(20, 75, 20, 0))
    cost_input.js_on_change('value', cost_handler)

    # Include DataTable of simulation results
    dt_columns = [
        TableColumn(field="multiplier", title="Multiplier"),
        TableColumn(field="fp",
                    title="False Positives",
                    formatter=NumberFormatter(format='0,0')),
        TableColumn(field="fn",
                    title="False Negatives",
                    formatter=NumberFormatter(format='0,0')),
        TableColumn(field="weighted_fn",
                    title="Weighted False Negatives",
                    formatter=NumberFormatter(format='0,0.00')),
        TableColumn(field="twe",
                    title="Total Weighted Errors",
                    formatter=NumberFormatter(format='0,0.00')),
        TableColumn(field="fpc",
                    title="Estimated FP Cost",
                    formatter=NumberFormatter(format='$0,0.00')),
        TableColumn(field="tec",
                    title="Estimated Total Cost",
                    formatter=NumberFormatter(format='$0,0.00')),
        TableColumn(field="precision",
                    title="Precision",
                    formatter=NumberFormatter(format='0.00%')),
        TableColumn(field="recall",
                    title="Recall",
                    formatter=NumberFormatter(format='0.00%')),
        TableColumn(field="f1",
                    title="F1 Score",
                    formatter=NumberFormatter(format='0.00%')),
    ]

    data_table = DataTable(source=loss_min,
                           columns=dt_columns,
                           width=1400,
                           height=700,
                           sizing_mode='fixed',
                           fit_columns=True,
                           reorderable=True,
                           sortable=True,
                           margin=(30, 0, 20, 0))

    # weighting_layout = column([weighting_div, evaluation, slider, data_table])
    weighting_layout = column(
        row(column(weighting_div, cost_input, comparison_div),
            column(slider, evaluation), Div(text='', height=200, width=60)),
        data_table)

    # Initialize visualizations in browser
    time.sleep(1.5)

    layout = grid([
        [title_div],
        [row(summary_div, vertical_line, hypo_div)],
        [
            row(Div(text='', height=200, width=60), exploratory,
                Div(text='', height=200, width=10), overlap_view,
                Div(text='', height=200, width=40))
        ],
        [Div(text='', height=10, width=200)],
        [
            row(Div(text='', height=200, width=60), density,
                Div(text='', height=200, width=10), errors,
                Div(text='', height=200, width=40))
        ],
        [Div(text='', height=10, width=200)],
        [
            row(Div(text='', height=200, width=60), weighting_layout,
                Div(text='', height=200, width=40))
        ],
    ])

    # Generate html resources for dashboard
    fonts = os.path.join(os.getcwd(), 'fonts')
    if os.path.isdir(os.path.join(session_folder, 'fonts')):
        shutil.rmtree(os.path.join(session_folder, 'fonts'))
        shutil.copytree(fonts, os.path.join(session_folder, 'fonts'))
    else:
        shutil.copytree(fonts, os.path.join(session_folder, 'fonts'))

    html = file_html(layout, INLINE, "Cream")
    with open(os.path.join(session_folder, f'{session_name}.html'),
              "w") as file:
        file.write(html)
    webbrowser.open("file://" +
                    os.path.join(session_folder, f'{session_name}.html'))
Exemplo n.º 9
0
def bokeh_ajax(request):
    """Assemble the embeddable Bokeh pieces for the live temperature plot.

    The plot starts with an empty data source. Pressing the "update graph"
    button makes the browser POST to ``/AJAXdata2`` and append the returned
    ``temperature``, ``time`` and ``id`` arrays to that source. Two text
    inputs let the user move the start and the end of the x-axis range.

    The initial x-range is read from the module-level ``dfDict`` and
    ``dfStatic`` objects, and the text inputs are pre-filled using the
    module-level ``dateFmt`` format string.

    Parameters
    ----------
    request
        Not used by this function.

    Returns
    -------
    dict
        ``jsResources`` and ``cssResources`` (rendered inline BokehJS
        assets), ``script2`` and ``div2`` (the output of ``components()``
        for the layout), plus a constant ``someword`` entry.
    """
    startDt = dfDict['time'][0].to_pydatetime()
    endDt = dfStatic['time'].iloc[-1].to_pydatetime()
    # note: need the below in order to display the bokeh plot
    jsResources = INLINE.render_js()
    # need the below in order to be able to properly interact with the plot and have the default bokeh plot
    # interaction tool to display
    cssResources = INLINE.render_css()

    # Starts empty; it is only filled from the browser by the AJAX callback below.
    source2 = ColumnDataSource(data={"time": [], "temperature": [], "id": []})
    livePlot2 = figure(x_axis_type="datetime",
                       x_range=[startDt, endDt],
                       y_range=(0, 25),
                       y_axis_label='Temperature (Celsius)',
                       title="Sea Surface Temperature at 43.18, -70.43",
                       plot_width=800)
    livePlot2.line("time", "temperature", source=source2)

    # Each text input parses its value with Date.parse() in the browser and
    # writes the result to one end of the plot's x-range.
    updateStartJS = CustomJS(args=dict(plotRange=livePlot2.x_range),
                             code="""
         var newStart = Date.parse(cb_obj.value)
         plotRange.start = newStart
         plotRange.change.emit()
     """)

    updateEndJS = CustomJS(args=dict(plotRange=livePlot2.x_range),
                           code="""
         var newEnd = Date.parse(cb_obj.value)
         plotRange.end = newEnd
         plotRange.change.emit()
     """)

    startInput = TextInput(value=startDt.strftime(dateFmt),
                           title="Enter Date in format: YYYY-mm-dd")
    startInput.js_on_change('value', updateStartJS)
    endInput = TextInput(value=endDt.strftime(dateFmt),
                         title="Enter Date in format: YYYY-mm-dd")
    endInput.js_on_change('value', updateEndJS)
    textWidgets = row(startInput, endInput)

    # https://stackoverflow.com/questions/37083998/flask-bokeh-ajaxdatasource
    # above stackoverflow helped a lot and is what the below CustomJS is based on

    # The callback calls jQuery.ajax, so jQuery must be loaded on the page
    # that embeds the returned script/div.
    callback = CustomJS(args=dict(source=source2),
                        code="""
        var time_values = "time";
        var temperatures = "temperature";
        var plot_data = source.data;

        jQuery.ajax({
            type: 'POST',
            url: '/AJAXdata2',
            data: {},
            dataType: 'json',
            success: function (json_from_server) {
                plot_data['temperature'] = plot_data['temperature'].concat(json_from_server['temperature']);
                plot_data['time'] = plot_data['time'].concat(json_from_server['time']);
                plot_data['id'] = plot_data['id'].concat(json_from_server['id']);
                source.change.emit();
            },
            error: function() {
                alert("Oh no, something went wrong. Search for an error " +
                      "message in Flask log and browser developer tools.");
            }
        });
        """)

    # The ``callback=`` widget property was removed in Bokeh 2.0. Attaching
    # the CustomJS to the ``button_click`` event runs the same code on click
    # and works both before and after that removal (the JS never reads cb_obj).
    manualUpdate = Button(label="update graph")
    manualUpdate.js_on_event('button_click', callback)
    widgets = widgetbox([manualUpdate])
    # IMPORTANT: key is that the widget you want to control plot X has to be in the same layout object as
    # said plot X . Therefore, when you call the components() method on it both widget and plot live within the
    # object, if they are not then the JS callbacks don't work because I think they do not know how to communicate
    # with one another
    layout2 = column(widgets, textWidgets, livePlot2)
    script2, div2 = components(layout2)
    return {
        'someword': "hello",
        'jsResources': jsResources,
        'cssResources': cssResources,
        'script2': script2,
        'div2': div2
    }
Exemplo n.º 10
0
class ViewerVIWidgets(object):
    """
    Encapsulates Bokeh widgets, and related callbacks, used for VI.

    Widgets are created on demand by the ``add_*`` methods and stored as
    attributes. Some methods read widgets created by others, so call order
    matters:

    - ``add_vi_scanner`` reads ``vi_filename_input`` (set by ``add_filename``).
    - ``add_vi_storage`` reads ``vi_filename_input``, ``vi_comment_input``,
      ``vi_name_input``, ``vi_quality_input`` and ``vi_issue_input``, so
      ``add_filename``, ``add_vi_comment``, ``add_vi_scanner``,
      ``add_vi_quality`` and ``add_vi_issues`` must have been called first.
    """

    def __init__(self, title, viewer_cds):
        """Collect flag labels, JS resources and the list of output fields.

        Parameters
        ----------
        title
            Stored as ``self.title``; later used in the default VI file name
            and passed to the JS callbacks.
        viewer_cds
            Object whose ``cds_metadata.data`` mapping is used to decide
            which entries of the module-level ``vi_file_fields`` are kept.
        """
        self.vi_quality_labels = [x["label"] for x in vi_flags if x["type"] == "quality"]
        self.vi_issue_labels = [x["label"] for x in vi_flags if x["type"] == "issue"]
        self.vi_issue_slabels = [x["shortlabel"] for x in vi_flags if x["type"] == "issue"]
        self.js_files = get_resources('js')
        self.title = title
        # Stays None unless add_countdown() is called.
        self.vi_countdown_toggle = None

        #- List of fields to be recorded in output csv file, contains for each field:
        # [ field name (in VI file header), associated variable in viewer_cds.cds_metadata ]
        # Only fields whose associated variable exists in cds_metadata are kept.
        metadata_columns = viewer_cds.cds_metadata.data
        self.output_file_fields = [
            [file_field[0], file_field[1]]
            for file_field in vi_file_fields
            if file_field[1] in metadata_columns
        ]

    def add_filename(self, username='******'):
        """Create ``vi_filename_input``, pre-filled with ``desi-vi_<title>_<username>.csv``."""
        #- VI file name
        default_vi_filename = "desi-vi_" + self.title + "_" + username + ".csv"
        self.vi_filename_input = TextInput(value=default_vi_filename, title="VI file name:")


    def add_vi_issues(self, viewer_cds, widgets):
        """Create the checkbox group for optional VI issue flags.

        On click, the JS callback concatenates the short labels of all active
        checkboxes and stores the result in ``VI_issue_flag`` at the index
        given by ``widgets.ifiberslider.value`` (a single space when nothing
        is checked), then calls ``autosave_vi_localStorage``.
        """
        #- Optional VI flags (issues)
        self.vi_issue_input = CheckboxGroup(labels=self.vi_issue_labels, active=[])
        vi_issue_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_issue_code += """
            var issues = []
            for (var i=0; i<vi_issue_labels.length; i++) {
                if (vi_issue_input.active.indexOf(i) >= 0) issues.push(vi_issue_slabels[i])
            }
            if (issues.length > 0) {
                cds_metadata.data['VI_issue_flag'][ifiberslider.value] = ( issues.join('') )
            } else {
                cds_metadata.data['VI_issue_flag'][ifiberslider.value] = " "
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_issue_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                ifiberslider=widgets.ifiberslider,
                vi_issue_input=self.vi_issue_input,
                vi_issue_labels=self.vi_issue_labels,
                vi_issue_slabels=self.vi_issue_slabels,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_issue_code,
        )
        self.vi_issue_input.js_on_click(self.vi_issue_callback)


    def add_vi_z(self, viewer_cds, widgets):
        """Create the VI redshift text input and the "Copy z to VI" button.

        Changing the text input stores its value in ``VI_z`` at the index
        given by ``widgets.ifiberslider.value``. Clicking the button copies
        ``widgets.z_input.value`` into the text input.
        """
        ## TODO: z_tovi behaviour if with_vi_widget=False ..?
        #- Optional VI information on redshift
        self.vi_z_input = TextInput(value='', title="VI redshift:")
        vi_z_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_z_code += """
            cds_metadata.data['VI_z'][ifiberslider.value]=vi_z_input.value
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_z_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                ifiberslider=widgets.ifiberslider,
                vi_z_input=self.vi_z_input,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_z_code,
        )
        self.vi_z_input.js_on_change('value', self.vi_z_callback)

        # Copy z value from redshift slider to VI
        self.z_tovi_button = Button(label='Copy z to VI')
        self.z_tovi_callback = CustomJS(
            args=dict(z_input=widgets.z_input, vi_z_input=self.vi_z_input),
            code="""
                vi_z_input.value = z_input.value
            """)
        self.z_tovi_button.js_on_event('button_click', self.z_tovi_callback)

    def add_vi_spectype(self, viewer_cds, widgets):
        """Create the VI spectral-type dropdown.

        Options are a single-space placeholder followed by the module-level
        ``vi_spectypes``. The placeholder is stored as an empty string in
        ``VI_spectype``; any other choice is stored as-is.
        """
        #- Optional VI information on spectral type
        self.vi_category_select = Select(value=' ', title="VI spectype:", options=([' '] + vi_spectypes))
        # The default value set to ' ' as setting value='' does not seem to work well with Select.
        vi_category_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_category_code += """
            if (vi_category_select.value == ' ') {
                cds_metadata.data['VI_spectype'][ifiberslider.value]=''
            } else {
                cds_metadata.data['VI_spectype'][ifiberslider.value]=vi_category_select.value
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_category_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                ifiberslider=widgets.ifiberslider,
                vi_category_select=self.vi_category_select,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_category_code,
        )
        self.vi_category_select.js_on_change('value', self.vi_category_callback)


    def add_vi_comment(self, viewer_cds, widgets):
        """Create the free-text VI comment input and the standard-comment dropdown.

        Before storing the comment in ``VI_comment``, the JS callback replaces
        commas with semicolons, maps a few non-ASCII symbols to ASCII words,
        and replaces any other non-ASCII character with ``?``. Selecting a
        standard comment appends it to the current comment text.
        """
        #- Optional VI comment
        self.vi_comment_input = TextInput(value='', title="VI comment (see guidelines):")
        vi_comment_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_comment_code += """
            var stored_comment = (vi_comment_input.value).replace(/./g, function(char){
                if ( char==',' ) {
                    return ';'
                } else if ( char.charCodeAt(0)<=127 ) {
                    return char
                } else {
                    var char_list = ['Å','α','β','γ','δ','λ']
                    var char_replace = ['Angstrom','alpha','beta','gamma','delta','lambda']
                    for (var i=0; i<char_list.length; i++) {
                        if ( char==char_list[i] ) return char_replace[i]
                    }
                    return '?'
                }
            })
            cds_metadata.data['VI_comment'][ifiberslider.value] = stored_comment
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
            """
        self.vi_comment_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                ifiberslider=widgets.ifiberslider,
                vi_comment_input=self.vi_comment_input,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_comment_code,
        )
        self.vi_comment_input.js_on_change('value', self.vi_comment_callback)

        #- List of "standard" VI comment
        self.vi_std_comment_select = Select(value=" ", title="Standard comment:", options=([' '] + vi_std_comments))
        vi_std_comment_code = """
            if (vi_std_comment_select.value != ' ') {
                if (vi_comment_input.value != '') {
                    vi_comment_input.value = vi_comment_input.value + " " + vi_std_comment_select.value
                } else {
                    vi_comment_input.value = vi_std_comment_select.value
                }
            }
            """
        self.vi_std_comment_callback = CustomJS(
            args=dict(
                vi_std_comment_select=self.vi_std_comment_select,
                vi_comment_input=self.vi_comment_input,
            ),
            code=vi_std_comment_code,
        )
        self.vi_std_comment_select.js_on_change('value', self.vi_std_comment_callback)


    def add_vi_quality(self, viewer_cds, widgets):
        """Create the main VI quality radio-button group.

        On click, the label of the active button is stored in
        ``VI_quality_flag`` at the index given by
        ``widgets.ifiberslider.value``; the string ``"-1"`` is stored when
        no button is active.
        """
        #- Main VI quality widget
        self.vi_quality_input = RadioButtonGroup(labels=self.vi_quality_labels)
        vi_quality_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_quality_code += """
            if ( vi_quality_input.active >= 0 ) {
                cds_metadata.data['VI_quality_flag'][ifiberslider.value] = vi_quality_labels[vi_quality_input.active]
            } else {
                cds_metadata.data['VI_quality_flag'][ifiberslider.value] = "-1"
            }
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            cds_metadata.change.emit()
        """
        self.vi_quality_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                vi_quality_input=self.vi_quality_input,
                vi_quality_labels=self.vi_quality_labels,
                ifiberslider=widgets.ifiberslider,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_quality_code,
        )
        self.vi_quality_input.js_on_click(self.vi_quality_callback)

    def add_vi_scanner(self, viewer_cds):
        """Create the scanner-name text input.

        Requires ``add_filename`` to have been called, because the callback
        is given ``self.vi_filename_input``. On change, the JS callback
        writes the new name into every entry of ``VI_scanner`` and rebuilds
        the file name by replacing its last ``_``-separated chunk with
        ``<name>.csv``.
        """
        #- VI scanner name
        initial_name = (viewer_cds.cds_metadata.data['VI_scanner'][0]).strip()
        self.vi_name_input = TextInput(value=initial_name, title="Your name (3-letter acronym):")
        vi_name_code = self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        vi_name_code += """
            for (var i=0; i<(cds_metadata.data['VI_scanner']).length; i++) {
                cds_metadata.data['VI_scanner'][i]=vi_name_input.value
            }
            var newname = vi_filename_input.value
            var name_chunks = newname.split("_")
            newname = ( name_chunks.slice(0,name_chunks.length-1).join("_") ) + ("_"+vi_name_input.value+".csv")
            vi_filename_input.value = newname
            autosave_vi_localStorage(output_file_fields, cds_metadata.data, title)
            """
        self.vi_name_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                vi_name_input=self.vi_name_input,
                vi_filename_input=self.vi_filename_input,
                title=self.title,
                output_file_fields=self.output_file_fields,
            ),
            code=vi_name_code,
        )
        self.vi_name_input.js_on_change('value', self.vi_name_callback)

    def add_guidelines(self):
        """Create ``vi_guideline_div``, an HTML summary of the VI flags.

        Quality flags and issue flags from the module-level ``vi_flags`` are
        listed in two sections, followed by a fixed note about comments.
        """
        #- Guidelines for VI flags
        html_parts = [
            "<B> VI guidelines </B>",
            "<BR /> <B> Classification flags: </B>",
        ]
        html_parts.extend(
            "<BR />&emsp;&emsp;[&emsp;" + flag['label'] + "&emsp;] " + flag['description']
            for flag in vi_flags
            if flag['type'] == 'quality'
        )
        html_parts.append("<BR /> <B> Optional indications: </B>")
        html_parts.extend(
            "<BR />&emsp;&emsp;[&emsp;" + flag['label']
            + "&emsp;(" + flag['shortlabel'] + ")&emsp;] " + flag['description']
            for flag in vi_flags
            if flag['type'] == 'issue'
        )
        html_parts.append("<BR /> <B> Comments: </B> <BR /> 100 characters max, avoid commas (automatically replaced by semi-columns), ASCII only.")
        self.vi_guideline_div = Div(text="".join(html_parts))

    def add_vi_storage(self, viewer_cds, widgets):
        """Create the download, recover and clear buttons for VI data.

        Requires ``add_filename``, ``add_vi_comment``, ``add_vi_scanner``,
        ``add_vi_quality`` and ``add_vi_issues`` to have been called, because
        the widgets they create are passed to the callbacks built here.
        """
        #- Save VI info to CSV file
        self.save_vi_button = Button(label="Download VI", button_type="success")
        save_vi_code = self.js_files["FileSaver.js"] + self.js_files["CSVtoArray.js"] + self.js_files["save_vi.js"]
        save_vi_code += """
            download_vi_file(output_file_fields, cds_metadata.data, vi_filename_input.value)
            """
        self.save_vi_callback = CustomJS(
            args=dict(
                cds_metadata=viewer_cds.cds_metadata,
                output_file_fields=self.output_file_fields,
                vi_filename_input=self.vi_filename_input,
            ),
            code=save_vi_code,
        )
        self.save_vi_button.js_on_event('button_click', self.save_vi_callback)

        #- Recover auto-saved VI data in browser
        self.recover_vi_button = Button(label="Recover auto-saved VI", button_type="default")
        recover_vi_code = self.js_files["CSVtoArray.js"] + self.js_files["recover_autosave_vi.js"]
        self.recover_vi_callback = CustomJS(
            args=dict(
                title=self.title,
                output_file_fields=self.output_file_fields,
                cds_metadata=viewer_cds.cds_metadata,
                # NOTE: this is the slider's value read when this method runs,
                # not the slider model itself.
                ifiber=widgets.ifiberslider.value,
                vi_comment_input=self.vi_comment_input,
                vi_name_input=self.vi_name_input,
                vi_quality_input=self.vi_quality_input,
                vi_issue_input=self.vi_issue_input,
                vi_issue_slabels=self.vi_issue_slabels,
                vi_quality_labels=self.vi_quality_labels,
            ),
            code=recover_vi_code,
        )
        self.recover_vi_button.js_on_event('button_click', self.recover_vi_callback)

        #- Clear all auto-saved VI
        self.clear_vi_button = Button(label="Clear all auto-saved VI", button_type="default")
        self.clear_vi_callback = CustomJS( args = dict(), code = """
            localStorage.clear()
            """ )
        self.clear_vi_button.js_on_event('button_click', self.clear_vi_callback)

    def add_vi_table(self, viewer_cds):
        """Create ``vi_table``, a DataTable over ``viewer_cds.cds_metadata``.

        The table shows the VI quality flag, issue flag, redshift, spectype
        and comment columns, and its height is set to ten row heights.
        """
        #- Show VI in a table
        vi_table_columns = [
            TableColumn(field="VI_quality_flag", title="Flag", width=40),
            TableColumn(field="VI_issue_flag", title="Opt.", width=50),
            TableColumn(field="VI_z", title="VI z", width=50),
            TableColumn(field="VI_spectype", title="VI spectype", width=150),
            TableColumn(field="VI_comment", title="VI comment", width=200)
        ]
        self.vi_table = DataTable(source=viewer_cds.cds_metadata, columns=vi_table_columns, width=500)
        self.vi_table.height = 10 * self.vi_table.row_height


    def add_countdown(self, vi_countdown):
        """Create ``vi_countdown_toggle``, a toggle that starts a JS countdown.

        Parameters
        ----------
        vi_countdown
            Countdown duration in minutes; must be strictly positive.

        Raises
        ------
        ValueError
            If ``vi_countdown`` is not strictly positive. This is an explicit
            raise rather than an ``assert`` so the check also runs under
            ``python -O``.
        """
        #- VI countdown
        if not vi_countdown > 0:
            raise ValueError("vi_countdown must be > 0, got {!r}".format(vi_countdown))
        self.vi_countdown_callback = CustomJS(args=dict(vi_countdown=vi_countdown), code="""
            if ( (cb_obj.label).includes('Start') ) { // Callback doesn't do anything after countdown started
                var countDownDate = new Date().getTime() + (1000 * 60 * vi_countdown);
                var countDownLoop = setInterval(function(){
                    var now = new Date().getTime();
                    var distance = countDownDate - now;
                    if (distance<0) {
                        cb_obj.label = "Time's up !";
                        clearInterval(countDownLoop);
                    } else {
                        var days = Math.floor(distance / (1000 * 60 * 60 * 24));
                        var hours = Math.floor((distance % (1000 * 60 * 60 * 24)) / (1000 * 60 * 60));
                        var minutes = Math.floor((distance % (1000 * 60 * 60)) / (1000 * 60));
                        var seconds = Math.floor((distance % (1000 * 60)) / 1000);
                        //var stuff = days + "d " + hours + "h " + minutes + "m " + seconds + "s ";
                        var stuff = minutes + "m " + seconds + "s ";
                        cb_obj.label = "Countdown: " + stuff;
                    }
                }, 1000);
            }
        """)
        self.vi_countdown_toggle = Toggle(label='Start countdown ('+str(vi_countdown)+' min)', active=False, button_type="success")
        self.vi_countdown_toggle.js_on_change('active', self.vi_countdown_callback)
Exemplo n.º 11
0
def photometry_plot(obj_id, user, width=600, device="browser"):
    """Create object photometry scatter plot.

    Parameters
    ----------
    obj_id : str
        ID of Obj to be plotted.

    Returns
    -------
    dict
        Returns Bokeh JSON embedding for the desired plot.
    """

    data = pd.read_sql(
        DBSession()
        .query(
            Photometry,
            Telescope.nickname.label("telescope"),
            Instrument.name.label("instrument"),
        )
        .join(Instrument, Instrument.id == Photometry.instrument_id)
        .join(Telescope, Telescope.id == Instrument.telescope_id)
        .filter(Photometry.obj_id == obj_id)
        .filter(
            Photometry.groups.any(Group.id.in_([g.id for g in user.accessible_groups]))
        )
        .statement,
        DBSession().bind,
    )

    if data.empty:
        return None, None, None

    # get spectra to annotate on phot plots
    spectra = (
        Spectrum.query_records_accessible_by(user)
        .filter(Spectrum.obj_id == obj_id)
        .all()
    )

    data['color'] = [get_color(f) for f in data['filter']]

    # get marker for each unique instrument
    instruments = list(data.instrument.unique())
    markers = []
    for i, inst in enumerate(instruments):
        markers.append(phot_markers[i % len(phot_markers)])

    filters = list(set(data['filter']))
    colors = [get_color(f) for f in filters]

    color_mapper = CategoricalColorMapper(factors=filters, palette=colors)
    color_dict = {'field': 'filter', 'transform': color_mapper}

    labels = []
    for i, datarow in data.iterrows():
        label = f'{datarow["instrument"]}/{datarow["filter"]}'
        if datarow['origin'] is not None:
            label += f'/{datarow["origin"]}'
        labels.append(label)

    data['label'] = labels
    data['zp'] = PHOT_ZP
    data['magsys'] = 'ab'
    data['alpha'] = 1.0
    data['lim_mag'] = (
        -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) + data['zp']
    )

    # Passing a dictionary to a bokeh datasource causes the frontend to die,
    # deleting the dictionary column fixes that
    del data['original_user_data']

    # keep track of things that are only upper limits
    data['hasflux'] = ~data['flux'].isna()

    # calculate the magnitudes - a photometry point is considered "significant"
    # or "detected" (and thus can be represented by a magnitude) if its snr
    # is above PHOT_DETECTION_THRESHOLD
    obsind = data['hasflux'] & (
        data['flux'].fillna(0.0) / data['fluxerr'] >= PHOT_DETECTION_THRESHOLD
    )
    data.loc[~obsind, 'mag'] = None
    data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP

    # calculate the magnitude errors using standard error propagation formulae
    # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae
    data.loc[~obsind, 'magerr'] = None
    coeff = 2.5 / np.log(10)
    magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux'])
    data.loc[obsind, 'magerr'] = magerrs
    data['obs'] = obsind
    data['stacked'] = False

    split = data.groupby('label', sort=False)

    finite = np.isfinite(data['flux'])
    fdata = data[finite]
    lower = np.min(fdata['flux']) * 0.95
    upper = np.max(fdata['flux']) * 1.05

    xmin = data['mjd'].min() - 2
    xmax = data['mjd'].max() + 2

    # Layout parameters based on device type
    active_drag = None if "mobile" in device or "tablet" in device else "box_zoom"
    tools = (
        'box_zoom,pan,reset'
        if "mobile" in device or "tablet" in device
        else "box_zoom,wheel_zoom,pan,reset,save"
    )
    legend_loc = "below" if "mobile" in device or "tablet" in device else "right"
    legend_orientation = (
        "vertical" if device in ["browser", "mobile_portrait"] else "horizontal"
    )

    # Compute a plot component height based on rough number of legend rows added below the plot
    # Values are based on default sizing of bokeh components and an estimate of how many
    # legend items would fit on the average device screen. Note that the legend items per
    # row is computed more exactly later once labels are extracted from the data (with the
    # add_plot_legend() function).
    #
    # The height is manually computed like this instead of using built in aspect_ratio/sizing options
    # because with the new Interactive Legend approach (instead of the legacy CheckboxLegendGroup), the
    # Legend component is considered part of the plot and plays into the sizing computations. Since the
    # number of items in the legend can alter the needed heights of the plot, using built-in Bokeh options
    # for sizing does not allow for keeping the actual graph part of the plot at a consistent aspect ratio.
    #
    # For the frame width, by default we take the desired plot width minus 64 for the y-axis/label taking
    # up horizontal space
    frame_width = width - 64
    if device == "mobile_portrait":
        legend_items_per_row = 1
        legend_row_height = 24
        aspect_ratio = 1
    elif device == "mobile_landscape":
        legend_items_per_row = 4
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "tablet_portrait":
        legend_items_per_row = 5
        legend_row_height = 50
        aspect_ratio = 1.5
    elif device == "tablet_landscape":
        legend_items_per_row = 7
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "browser":
        # Width minus some base width for the legend, which is only a column to the right
        # for browser mode
        frame_width = width - 200

    height = (
        500
        if device == "browser"
        else math.floor(width / aspect_ratio)
        + legend_row_height * int(len(split) / legend_items_per_row)
        + 30  # 30 is the height of the toolbar
    )

    plot = figure(
        frame_width=frame_width,
        height=height,
        active_drag=active_drag,
        tools=tools,
        toolbar_location='above',
        toolbar_sticky=True,
        y_range=(lower, upper),
        min_border_right=16,
        x_axis_location='above',
        sizing_mode="stretch_width",
    )

    plot.xaxis.axis_label = 'MJD'
    now = Time.now().mjd
    plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)}
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below')

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}
    legend_items = []
    for i, (label, sdf) in enumerate(split):
        renderers = []

        # for the flux plot, we only show things that have a flux value
        df = sdf[sdf['hasflux']]

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            alpha='alpha',
            source=ColumnDataSource(df),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    stacked=[],
                    instrument=[],
                )
            ),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for d, ro in df.iterrows():
            px = ro['mjd']
            py = ro['flux']
            err = ro['fluxerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x, ys=y_err_y, color=df['color'], alpha=[1.0] * len(df)
                )
            ),
        )
        renderers.append(model_dict[key])

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            # legend_label=label,
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )
        renderers.append(model_dict[key])

        legend_items.append(LegendItem(label=label, renderers=renderers))

    if device == "mobile_portrait":
        plot.xaxis.ticker.desired_num_ticks = 5
    plot.yaxis.axis_label = 'Flux (μJy)'
    plot.toolbar.logo = None

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)
    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    callback = CustomJS(
        args={'slider': slider, 'n_labels': len(split), **model_dict},
        code=open(
            os.path.join(os.path.dirname(__file__), '../static/js/plotjs', 'stackf.js')
        )
        .read()
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )

    slider.js_on_change('value', callback)

    # Mark the first and last detections
    detection_dates = data[data['hasflux']]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        first_color = "#34b4eb"
        last_color = "#8992f5"
        midpoint = (upper + lower) / 2
        line_top = 5 * upper - 4 * midpoint
        line_bottom = 5 * lower - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, lower, upper)
    layout = column(slider, plot, width=width, height=height)

    p1 = Panel(child=layout, title='Flux')

    # now make the mag light curve
    ymax = (
        np.nanmax(
            (
                np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmax(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        + 0.1
    )
    ymin = (
        np.nanmin(
            (
                np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmin(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        - 0.1
    )

    plot = figure(
        frame_width=frame_width,
        height=height,
        active_drag=active_drag,
        tools=tools,
        y_range=(ymax, ymin),
        x_range=(xmin, xmax),
        toolbar_location='above',
        toolbar_sticky=True,
        x_axis_location='above',
        sizing_mode="stretch_width",
    )

    plot.xaxis.axis_label = 'MJD'
    now = Time.now().mjd
    plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)}
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below')

    obj = DBSession().query(Obj).get(obj_id)
    if obj.dm is not None:
        plot.extra_y_ranges = {
            "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
        }
        plot.add_layout(
            LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
        )

    # Mark the first and last detections again
    detection_dates = data[obsind]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        midpoint = (ymax + ymin) / 2
        line_top = 5 * ymax - 4 * midpoint
        line_bottom = 5 * ymin - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
                point_policy='follow_mouse',
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, ymax, ymin)

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}

    # Legend items are individually stored instead of being applied
    # directly when plotting so that they can be separated into multiple
    # Legend() components if needed (to simulate horizontal row wrapping).
    # This is necessary because Bokeh does not support row wrapping with
    # horizontally-oriented legends out-of-the-box.
    legend_items = []
    for i, (label, df) in enumerate(split):
        renderers = []

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            alpha='alpha',
            source=ColumnDataSource(df[df['obs']]),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        unobs_source = df[~df['obs']].copy()
        unobs_source.loc[:, 'alpha'] = 0.8

        key = f'unobs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color=color_dict,
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha='alpha',
            source=ColumnDataSource(unobs_source),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color=color_dict,
            marker=factor_mark('instrument', markers, instruments),
            fill_color='color',
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    instrument=[],
                    stacked=[],
                )
            ),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for d, ro in df[df['obs']].iterrows():
            px = ro['mjd']
            py = ro['mag']
            err = ro['magerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x,
                    ys=y_err_y,
                    color=df[df['obs']]['color'],
                    alpha=[1.0] * len(df[df['obs']]),
                )
            ),
        )
        renderers.append(model_dict[key])

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )
        renderers.append(model_dict[key])

        key = f'unobsbin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color=color_dict,
            alpha=0.8,
            source=ColumnDataSource(
                data=dict(
                    mjd=[],
                    flux=[],
                    fluxerr=[],
                    filter=[],
                    color=[],
                    lim_mag=[],
                    mag=[],
                    magerr=[],
                    instrument=[],
                    stacked=[],
                )
            ),
        )
        imhover.renderers.append(model_dict[key])
        renderers.append(model_dict[key])

        key = f'all{i}'
        model_dict[key] = ColumnDataSource(df)

        key = f'bold{i}'
        model_dict[key] = ColumnDataSource(
            df[
                [
                    'mjd',
                    'flux',
                    'fluxerr',
                    'mag',
                    'magerr',
                    'filter',
                    'zp',
                    'magsys',
                    'lim_mag',
                    'stacked',
                ]
            ]
        )

        legend_items.append(LegendItem(label=label, renderers=renderers))

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)

    plot.yaxis.axis_label = 'AB mag'
    plot.toolbar.logo = None

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    button = Button(label="Export Bold Light Curve to CSV")
    button.js_on_click(
        CustomJS(
            args={'slider': slider, 'n_labels': len(split), **model_dict},
            code=open(
                os.path.join(
                    os.path.dirname(__file__), '../static/js/plotjs', "download.js"
                )
            )
            .read()
            .replace('objname', obj_id)
            .replace('default_zp', str(PHOT_ZP)),
        )
    )

    # Don't need to expose CSV download on mobile
    top_layout = (
        slider if "mobile" in device or "tablet" in device else row(slider, button)
    )

    callback = CustomJS(
        args={'slider': slider, 'n_labels': len(split), **model_dict},
        code=open(
            os.path.join(os.path.dirname(__file__), '../static/js/plotjs', 'stackm.js')
        )
        .read()
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)
    layout = column(top_layout, plot, width=width, height=height)

    p2 = Panel(child=layout, title='Mag')

    # now make period plot

    # get periods from annotations
    annotation_list = obj.get_annotations_readable_by(user)
    period_labels = []
    period_list = []
    for an in annotation_list:
        if 'period' in an.data:
            period_list.append(an.data['period'])
            period_labels.append(an.origin + ": %.9f" % an.data['period'])

    if len(period_list) > 0:
        period = period_list[0]
    else:
        period = None
    # don't generate if no period annotated
    if period is not None:
        # bokeh figure for period plotting
        period_plot = figure(
            frame_width=frame_width,
            height=height,
            active_drag=active_drag,
            tools=tools,
            y_range=(ymax, ymin),
            x_range=(-0.01, 2.01),  # initially one phase
            toolbar_location='above',
            toolbar_sticky=False,
            x_axis_location='below',
            sizing_mode="stretch_width",
        )

        # axis labels
        period_plot.xaxis.axis_label = 'phase'
        period_plot.yaxis.axis_label = 'mag'
        period_plot.toolbar.logo = None

        # do we have a distance modulus (dm)?
        obj = DBSession().query(Obj).get(obj_id)
        if obj.dm is not None:
            period_plot.extra_y_ranges = {
                "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
            }
            period_plot.add_layout(
                LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
            )

        # initiate hover tool
        period_imhover = HoverTool(tooltips=tooltip_format)
        period_imhover.renderers = []
        period_plot.add_tools(period_imhover)

        # initiate period radio buttons
        period_selection = RadioGroup(labels=period_labels, active=0)

        phase_selection = RadioGroup(labels=["One phase", "Two phases"], active=1)

        # store all the plot data
        period_model_dict = {}

        # iterate over each filter
        legend_items = []
        for i, (label, df) in enumerate(split):
            renderers = []
            # fold x-axis on period in days
            df['mjd_folda'] = (df['mjd'] % period) / period
            df['mjd_foldb'] = df['mjd_folda'] + 1.0

            # phase plotting
            for ph in ['a', 'b']:
                key = 'fold' + ph + f'{i}'
                period_model_dict[key] = period_plot.scatter(
                    x='mjd_fold' + ph,
                    y='mag',
                    color='color',
                    marker=factor_mark('instrument', markers, instruments),
                    fill_color=color_dict,
                    alpha='alpha',
                    # visible=('a' in ph),
                    source=ColumnDataSource(df[df['obs']]),  # only visible data
                )
                # add to hover tool
                period_imhover.renderers.append(period_model_dict[key])
                renderers.append(period_model_dict[key])

                # errorbars for phases
                key = 'fold' + ph + f'err{i}'
                y_err_x = []
                y_err_y = []

                # get each visible error value
                for d, ro in df[df['obs']].iterrows():
                    px = ro['mjd_fold' + ph]
                    py = ro['mag']
                    err = ro['magerr']
                    # set up error tuples
                    y_err_x.append((px, px))
                    y_err_y.append((py - err, py + err))
                # plot phase errors
                period_model_dict[key] = period_plot.multi_line(
                    xs='xs',
                    ys='ys',
                    color='color',
                    alpha='alpha',
                    # visible=('a' in ph),
                    source=ColumnDataSource(
                        data=dict(
                            xs=y_err_x,
                            ys=y_err_y,
                            color=df[df['obs']]['color'],
                            alpha=[1.0] * len(df[df['obs']]),
                        )
                    ),
                )
                renderers.append(period_model_dict[key])

            legend_items.append(LegendItem(label=label, renderers=renderers))

        add_plot_legend(
            period_plot, legend_items, width, legend_orientation, legend_loc
        )

        # set up period adjustment text box
        period_title = Div(text="Period (days): ")
        period_textinput = TextInput(value=str(period if period is not None else 0.0))
        period_textinput.js_on_change(
            'value',
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'numphases': phase_selection,
                    'n_labels': len(split),
                    'p': period_plot,
                    **period_model_dict,
                },
                code=open(
                    os.path.join(
                        os.path.dirname(__file__), '../static/js/plotjs', 'foldphase.js'
                    )
                ).read(),
            ),
        )
        # a way to modify the period
        period_double_button = Button(label="*2", width=30)
        period_double_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(2.*period).toFixed(9);
                """,
            )
        )
        period_halve_button = Button(label="/2", width=30)
        period_halve_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                        const period = parseFloat(textinput.value);
                        textinput.value = parseFloat(period/2.).toFixed(9);
                        """,
            )
        )
        # a way to select the period
        period_selection.js_on_click(
            CustomJS(
                args={'textinput': period_textinput, 'periods': period_list},
                code="""
                textinput.value = parseFloat(periods[this.active]).toFixed(9);
                """,
            )
        )
        phase_selection.js_on_click(
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'numphases': phase_selection,
                    'n_labels': len(split),
                    'p': period_plot,
                    **period_model_dict,
                },
                code=open(
                    os.path.join(
                        os.path.dirname(__file__), '../static/js/plotjs', 'foldphase.js'
                    )
                ).read(),
            )
        )

        # layout
        if device == "mobile_portrait":
            period_controls = column(
                row(
                    period_title,
                    period_textinput,
                    period_double_button,
                    period_halve_button,
                    width=width,
                    sizing_mode="scale_width",
                ),
                phase_selection,
                period_selection,
                width=width,
            )
            # Add extra height to plot based on period control components added
            # 18 is the height of each period selection radio option (per default font size)
            # and the 130 encompasses the other components which are consistent no matter
            # the data size.
            height += 130 + 18 * len(period_labels)
        else:
            period_controls = column(
                row(
                    period_title,
                    period_textinput,
                    period_double_button,
                    period_halve_button,
                    phase_selection,
                    width=width,
                    sizing_mode="scale_width",
                ),
                period_selection,
                margin=10,
            )
            # Add extra height to plot based on period control components added
            # Numbers are derived in similar manner to the "mobile_portrait" case above
            height += 90 + 18 * len(period_labels)

        period_layout = column(period_plot, period_controls, width=width, height=height)

        # Period panel
        p3 = Panel(child=period_layout, title='Period')

        # tabs for mag, flux, period
        tabs = Tabs(tabs=[p2, p1, p3], width=width, height=height, sizing_mode='fixed')
    else:
        # tabs for mag, flux
        tabs = Tabs(tabs=[p2, p1], width=width, height=height + 90, sizing_mode='fixed')
    return bokeh_embed.json_item(tabs)
Exemplo n.º 12
0
def spectroscopy_plot(obj_id, user, spec_id=None, width=600, device="browser"):
    """Create the interactive spectroscopy plot for an object.

    Each spectrum is drawn as a step plot normalized to a median absolute
    flux of 1. Below the plot are checkbox groups that toggle the spectral
    lines listed in ``SPEC_LINES`` and controls for redshift and expansion
    velocity that shift those lines.

    Parameters
    ----------
    obj_id : str
        ID of the Obj whose spectra are plotted.
    user
        Requesting user; only spectra attached to one of
        ``user.accessible_groups`` are shown.
    spec_id : int or str, optional
        If given, only the spectrum with this ID is plotted.
    width : int, optional
        Total width of the layout, in pixels.
    device : str, optional
        One of "browser", "mobile_portrait", "mobile_landscape",
        "tablet_portrait" or "tablet_landscape". Any other value uses the
        "browser" plot geometry.

    Returns
    -------
    dict or tuple
        Bokeh JSON embedding for the plot layout, or ``(None, None, None)``
        when there is no matching spectrum.
    """
    obj = Obj.query.get(obj_id)
    spectra = (
        DBSession()
        .query(Spectrum)
        .join(Obj)
        .join(GroupSpectrum)
        .filter(
            Spectrum.obj_id == obj_id,
            GroupSpectrum.group_id.in_([g.id for g in user.accessible_groups]),
        )
    ).all()

    if spec_id is not None:
        spectra = [spec for spec in spectra if spec.id == int(spec_id)]
    if not spectra:
        return None, None, None

    rainbow = cm.get_cmap('rainbow', len(spectra))
    palette = list(map(rgb2hex, rainbow(range(len(spectra)))))
    color_map = dict(zip([s.id for s in spectra], palette))
    # Lets the plotting loop reuse the rows loaded above instead of issuing
    # one extra query per spectrum.
    spectra_by_id = {s.id: s for s in spectra}

    data = []
    for s in spectra:
        # normalize spectra to a median flux of 1 for easy comparison
        normfac = np.nanmedian(np.abs(s.fluxes))
        normfac = normfac if normfac != 0.0 else 1e-20

        if s.assignment is not None:
            pi = s.assignment.run.pi
        elif s.followup_request is not None:
            pi = s.followup_request.allocation.pi
        else:
            pi = ""

        data.append(
            pd.DataFrame(
                {
                    'wavelength': s.wavelengths,
                    'flux': s.fluxes / normfac,
                    'id': s.id,
                    'telescope': s.instrument.telescope.name,
                    'instrument': s.instrument.name,
                    'date_observed': s.observed_at.isoformat(
                        sep=' ', timespec='seconds'
                    ),
                    'pi': pi,
                }
            )
        )
    data = pd.concat(data)

    data.sort_values(by=['date_observed', 'wavelength'], inplace=True)

    # Smooth each spectrum by using a rolling average
    smoothed_data = pd.concat(
        [
            pd.DataFrame({'wavelength': s.wavelengths, 'flux': s.fluxes})
            .rolling(2)
            .mean(numeric_only=True)
            .dropna()
            for s in spectra
        ]
    )

    split = data.groupby('id', sort=False)

    hover = HoverTool(
        tooltips=[
            ('wavelength', '@wavelength{0,0.000}'),
            ('flux', '@flux'),
            ('telescope', '@telescope'),
            ('instrument', '@instrument'),
            ('UTC date observed', '@date_observed'),
            ('PI', '@pi'),
        ]
    )

    # Initial axis ranges; note that the y range is replaced further below.
    smoothed_max = np.max(smoothed_data['flux'])
    smoothed_min = np.min(smoothed_data['flux'])
    ymax = smoothed_max * 1.05
    ymin = smoothed_min - 0.05 * (smoothed_max - smoothed_min)
    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100

    has_redshift = obj.redshift is not None and obj.redshift > 0
    # Value used to initialize the redshift widgets and the spectral lines
    redshift = 0.0 if obj.redshift is None else obj.redshift

    small_screen = "mobile" in device or "tablet" in device
    active_drag = None if small_screen else "box_zoom"
    tools = "box_zoom, pan, reset" if small_screen else "box_zoom,wheel_zoom,pan,reset"

    # These values are equivalent to the photometry plot values:
    # device -> (legend_items_per_row, legend_row_height, aspect_ratio)
    compact_layouts = {
        "mobile_portrait": (1, 24, 1),
        "mobile_landscape": (4, 50, 1.8),
        "tablet_portrait": (5, 50, 1.5),
        "tablet_landscape": (7, 50, 1.8),
    }
    if device in compact_layouts:
        legend_items_per_row, legend_row_height, aspect_ratio = compact_layouts[device]
        frame_width = width - 64
        plot_height = (
            math.floor(width / aspect_ratio)
            + legend_row_height * int(len(split) / legend_items_per_row)
            + 30  # 30 is the height of the toolbar
        )
    else:
        # "browser", plus any unrecognized device string (which would
        # otherwise leave the legend/aspect variables undefined).
        frame_width = width - 200
        plot_height = 400

    plot = figure(
        frame_width=frame_width,
        height=plot_height,
        y_range=(ymin, ymax),
        x_range=(xmin, xmax),
        tools=tools,
        toolbar_location="above",
        active_drag=active_drag,
    )
    plot.add_tools(hover)

    # Renderers are collected by name so they can be handed to the JS callbacks
    model_dict = {}
    legend_items = []
    for i, (key, df) in enumerate(split):
        s = spectra_by_id[key]
        label = f'{s.instrument.name} ({s.observed_at.date().strftime("%m/%d/%y")})'
        step_renderer = plot.step(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
        )
        model_dict['s' + str(i)] = step_renderer
        legend_items.append(LegendItem(label=label, renderers=[step_renderer]))
        # Fully transparent line drawn over the same data (not in the legend)
        model_dict['l' + str(i)] = plot.line(
            x='wavelength',
            y='flux',
            color=color_map[key],
            source=ColumnDataSource(df),
            line_alpha=0.0,
        )
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None
    if has_redshift:
        # Secondary x axis showing the rest-frame wavelength
        plot.extra_x_ranges = {
            "rest_wave": Range1d(
                start=xmin / (1.0 + obj.redshift), end=xmax / (1.0 + obj.redshift)
            )
        }
        plot.add_layout(
            LinearAxis(x_range_name="rest_wave", axis_label="Rest Wavelength (Å)"),
            'above',
        )

    # TODO how to choose a good default?
    plot.y_range = Range1d(0, 1.03 * data.flux.max())

    legend_loc = "below" if small_screen else "right"
    legend_orientation = (
        "vertical" if device in ["browser", "mobile_portrait"] else "horizontal"
    )

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)

    # 20 is for padding
    slider_width = width if "mobile" in device else int(width / 2) - 20
    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=redshift,
        start=0.0,
        end=3.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(redshift))
    z_slider.js_on_change(
        'value',
        CustomJS(
            args={'slider': z_slider, 'textinput': z_textinput},
            code="""
            textinput.value = parseFloat(slider.value).toFixed(3);
            textinput.change.emit();
        """,
        ),
    )
    z = column(
        z_title,
        z_slider,
        z_textinput,
        width=slider_width,
        margin=(4, 10, 0, 10),
    )

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.js_on_change(
        'value',
        CustomJS(
            args={'slider': v_exp_slider, 'textinput': v_exp_textinput},
            code="""
            textinput.value = parseFloat(slider.value).toFixed(0);
            textinput.change.emit();
        """,
        ),
    )
    v_exp = column(
        v_title,
        v_exp_slider,
        v_exp_textinput,
        width=slider_width,
        margin=(0, 10, 0, 10),
    )

    # One (initially hidden) set of vertical segments per SPEC_LINES entry,
    # placed at the wavelengths shifted by the object's redshift.
    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # TODO change limits
            y0=0,
            y1=1e4,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral line legend into columns
    if device == "mobile_portrait":
        columns = 3
    elif device == "mobile_landscape":
        columns = 5
    else:
        columns = 7
    # Chunk SPEC_LINES into rows of `columns` entries (padding with None),
    # then transpose so that each tuple holds the entries of one column.
    element_dicts = zip(*itertools.zip_longest(*[iter(SPEC_LINES.items())] * columns))

    elements_groups = []  # The Bokeh checkbox groups
    callbacks = []  # The checkbox callbacks for each element
    for column_idx, element_dict in enumerate(element_dicts):
        element_dict = [e for e in element_dict if e is not None]
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(
            labels=labels, active=[], colors=colors, width=width // (columns + 1)
        )
        elements_groups.append(elements)

        # The n-th checkbox of this column controls renderer
        # "el{column_idx + n * columns}", which is how the JS loop below walks.
        callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {column_idx} +  {columns} * elements.labels.length;
            let local_i = 0;
            for (let i = {column_idx}; i < i_max; i = i + {columns}) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )
        elements.js_on_click(callback)
        callbacks.append(callback)

    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': z_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(z.value).toFixed(3);
        """,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={
                'z': z_textinput,
                'slider': v_exp_slider,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code="""
            // Update slider value to match text input
            slider.value = parseFloat(v_exp.value).toFixed(3);
        """,
        ),
    )

    # Update the element spectral lines as well
    for callback in callbacks:
        z_textinput.js_on_change('value', callback)
        v_exp_textinput.js_on_change('value', callback)

    # Add some height for the checkboxes and sliders
    if device == "mobile_portrait":
        height = plot_height + 400
    elif device == "mobile_landscape":
        height = plot_height + 350
    else:
        height = plot_height + 200

    row2 = row(elements_groups)
    row3 = column(z, v_exp) if "mobile" in device else row(z, v_exp)
    layout = column(
        plot,
        row2,
        row3,
        sizing_mode='stretch_width',
        width=width,
        height=height,
    )
    return bokeh_embed.json_item(layout)
Exemplo n.º 13
0
def _read_plotjs(filename):
    """Return the text of a JavaScript file from ``../static/js/plotjs``.

    The path is resolved relative to this module.  The file is opened in a
    ``with`` block so the handle is closed as soon as it has been read.
    """
    path = os.path.join(os.path.dirname(__file__), '../static/js/plotjs',
                        filename)
    with open(path) as js_file:
        return js_file.read()


def _errorbar_segments(frame, x_col, y_col, err_col):
    """Build ``multi_line`` coordinates for vertical error bars.

    For every row of ``frame`` one segment is produced, running from
    ``y - err`` to ``y + err`` at constant ``x``.

    Returns
    -------
    tuple of list
        ``(xs, ys)`` where each element is a 2-tuple per row.
    """
    xs = []
    ys = []
    for px, py, err in zip(frame[x_col], frame[y_col], frame[err_col]):
        xs.append((px, px))
        ys.append((py - err, py + err))
    return xs, ys


def photometry_plot(obj_id, user, width=600, height=300, device="browser"):
    """Create object photometry scatter plot.

    Builds a tabbed Bokeh layout holding a magnitude light curve, a flux
    light curve and, when at least one readable annotation carries a
    ``period`` entry, a phase-folded light curve.

    Parameters
    ----------
    obj_id : str
        ID of Obj to be plotted.
    user : object
        Requesting user.  Photometry is limited to the groups in
        ``user.accessible_groups``; spectra and annotations are fetched
        through the corresponding "accessible by" / "readable by" helpers.
    width : int, optional
        Width passed to the layouts and the tab container; the legend
        checkbox groups get ``width // 5``.
    height : int, optional
        Height passed to the tab container.
    device : str, optional
        Device hint.  Values containing ``"mobile"`` or ``"tablet"`` get a
        reduced toolbar and a stacked layout; ``"mobile_landscape"`` and
        ``"mobile_portrait"`` additionally change the aspect ratio and the
        number of x ticks.

    Returns
    -------
    dict or tuple
        Bokeh JSON embedding for the desired plot, or
        ``(None, None, None)`` when the query returns no photometry.
    """

    data = pd.read_sql(
        DBSession().query(
            Photometry,
            Telescope.nickname.label("telescope"),
            Instrument.name.label("instrument"),
        ).join(Instrument, Instrument.id == Photometry.instrument_id).join(
            Telescope, Telescope.id == Instrument.telescope_id).filter(
                Photometry.obj_id == obj_id).filter(
                    Photometry.groups.any(
                        Group.id.in_([g.id for g in user.accessible_groups
                                      ]))).statement,
        DBSession().bind,
    )

    if data.empty:
        # NOTE(review): this 3-tuple differs in shape from the normal return
        # value (a single json_item dict).  Kept as-is because callers may
        # rely on it -- confirm and consider returning a plain None.
        return None, None, None

    # get spectra to annotate on phot plots
    spectra = (Spectrum.query_records_accessible_by(user).filter(
        Spectrum.obj_id == obj_id).all())

    data['color'] = [get_color(f) for f in data['filter']]

    # One legend label per point: "<instrument>/<filter>[/<origin>]"
    labels = []
    for instrument, filt, origin in zip(data['instrument'], data['filter'],
                                        data['origin']):
        label = f'{instrument}/{filt}'
        if origin is not None:
            label += f'/{origin}'
        labels.append(label)

    data['label'] = labels
    data['zp'] = PHOT_ZP
    data['magsys'] = 'ab'
    data['alpha'] = 1.0
    data['lim_mag'] = (
        -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) +
        data['zp'])

    # Passing a dictionary to a bokeh datasource causes the frontend to die,
    # deleting the dictionary column fixes that
    del data['original_user_data']

    # keep track of things that are only upper limits
    data['hasflux'] = ~data['flux'].isna()

    # calculate the magnitudes - a photometry point is considered "significant"
    # or "detected" (and thus can be represented by a magnitude) if its snr
    # is above PHOT_DETECTION_THRESHOLD
    obsind = data['hasflux'] & (data['flux'].fillna(0.0) / data['fluxerr'] >=
                                PHOT_DETECTION_THRESHOLD)
    data.loc[~obsind, 'mag'] = None
    data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP

    # calculate the magnitude errors using standard error propagation formulae
    # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae
    data.loc[~obsind, 'magerr'] = None
    coeff = 2.5 / np.log(10)
    magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux'])
    data.loc[obsind, 'magerr'] = magerrs
    data['obs'] = obsind
    data['stacked'] = False

    # One group per legend label, in order of first appearance
    split = data.groupby('label', sort=False)

    # y range of the flux plot: 5% padding around the finite flux values
    finite = np.isfinite(data['flux'])
    fdata = data[finite]
    lower = np.min(fdata['flux']) * 0.95
    upper = np.max(fdata['flux']) * 1.05

    small_screen = "mobile" in device or "tablet" in device
    active_drag = None if small_screen else "box_zoom"
    tools = ('box_zoom,pan,reset'
             if small_screen else "box_zoom,wheel_zoom,pan,reset,save")

    # Colors of the first/last detection markers.  Defined unconditionally
    # because both the flux and the mag plot use them.
    first_color = "#34b4eb"
    last_color = "#8992f5"

    plot = figure(
        aspect_ratio=2.0 if device == "mobile_landscape" else 1.5,
        sizing_mode='scale_both',
        active_drag=active_drag,
        tools=tools,
        toolbar_location='above',
        toolbar_sticky=True,
        y_range=(lower, upper),
        min_border_right=16,
    )
    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    # Renderers are stored under "<kind><group index>" keys; the whole dict
    # is handed to the JavaScript callbacks as named arguments.
    model_dict = {}

    for i, (_, sdf) in enumerate(split):

        # for the flux plot, we only show things that have a flux value
        df = sdf[sdf['hasflux']]

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(df),
        )

        imhover.renderers.append(model_dict[key])

        # Empty source for binned points; it starts with no rows here
        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            source=ColumnDataSource(data=dict(
                mjd=[],
                flux=[],
                fluxerr=[],
                filter=[],
                color=[],
                lim_mag=[],
                mag=[],
                magerr=[],
                stacked=[],
                instrument=[],
            )),
        )

        imhover.renderers.append(model_dict[key])

        key = f'obserr{i}'
        y_err_x, y_err_y = _errorbar_segments(df, 'mjd', 'flux', 'fluxerr')

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(data=dict(xs=y_err_x,
                                              ys=y_err_y,
                                              color=df['color'],
                                              alpha=[1.0] * len(df))),
        )

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

    plot.xaxis.axis_label = 'MJD'
    if device == "mobile_portrait":
        plot.xaxis.ticker.desired_num_ticks = 5
    plot.yaxis.axis_label = 'Flux (μJy)'
    plot.toolbar.logo = None

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline="tablet" in device,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={
                'toggle': toggle,
                **model_dict
            },
            code=_read_plotjs('togglef.js'),
        ))

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    # The JS template contains the placeholder tokens 'default_zp' and
    # 'detect_thresh'; they are replaced with the Python-side constants.
    callback = CustomJS(
        args={
            'slider': slider,
            'toggle': toggle,
            **model_dict
        },
        code=_read_plotjs('stackf.js').replace(
            'default_zp', str(PHOT_ZP)).replace(
                'detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )

    slider.js_on_change('value', callback)

    # Mark the first and last detections
    detection_dates = data[data['hasflux']]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        # The marker line reaches five half-ranges above and below the
        # midpoint of the y range, i.e. well beyond the initial view.
        midpoint = (upper + lower) / 2
        line_top = 5 * upper - 4 * midpoint
        line_bottom = 5 * lower - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            ))
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
            ))

    # Mark when spectra were taken
    annotate_spec(plot, spectra, lower, upper)

    plot_layout = (column(plot, toggle) if small_screen else row(
        plot, toggle))
    layout = column(slider,
                    plot_layout,
                    sizing_mode='scale_width',
                    width=width)

    p1 = Panel(child=layout, title='Flux')

    # now make the mag light curve
    # y limits cover detections ('mag') and upper limits ('lim_mag'), padded
    # by 0.1 mag on each side
    ymax = (np.nanmax((
        np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
        np.nanmax(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
    )) + 0.1)
    ymin = (np.nanmin((
        np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
        np.nanmin(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
    )) - 0.1)

    xmin = data['mjd'].min() - 2
    xmax = data['mjd'].max() + 2

    # y_range is given as (ymax, ymin) so the axis is inverted
    plot = figure(
        aspect_ratio=2.0 if device == "mobile_landscape" else 1.5,
        sizing_mode='scale_both',
        width=width,
        active_drag=active_drag,
        tools=tools,
        y_range=(ymax, ymin),
        x_range=(xmin, xmax),
        toolbar_location='above',
        toolbar_sticky=True,
        x_axis_location='above',
    )

    # Mark the first and last detections again
    detection_dates = data[obsind]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        midpoint = (ymax + ymin) / 2
        line_top = 5 * ymax - 4 * midpoint
        line_bottom = 5 * ymin - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            ))
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
                point_policy='follow_mouse',
            ))

    # Mark when spectra were taken
    annotate_spec(plot, spectra, ymax, ymin)

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}

    for i, (_, df) in enumerate(split):

        # rows that count as detections (see 'obs' above)
        detected = df[df['obs']]

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(detected),
        )

        imhover.renderers.append(model_dict[key])

        # non-detections are drawn as upper limits, slightly transparent
        unobs_source = df[~df['obs']].copy()
        unobs_source.loc[:, 'alpha'] = 0.8

        key = f'unobs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha='alpha',
            source=ColumnDataSource(unobs_source),
        )

        imhover.renderers.append(model_dict[key])

        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            source=ColumnDataSource(data=dict(
                mjd=[],
                flux=[],
                fluxerr=[],
                filter=[],
                color=[],
                lim_mag=[],
                mag=[],
                magerr=[],
                instrument=[],
                stacked=[],
            )),
        )

        imhover.renderers.append(model_dict[key])

        key = f'obserr{i}'
        y_err_x, y_err_y = _errorbar_segments(detected, 'mjd', 'mag',
                                              'magerr')

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(data=dict(
                xs=y_err_x,
                ys=y_err_y,
                color=detected['color'],
                alpha=[1.0] * len(detected),
            )),
        )

        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

        key = f'unobsbin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha=0.8,
            source=ColumnDataSource(data=dict(
                mjd=[],
                flux=[],
                fluxerr=[],
                filter=[],
                color=[],
                lim_mag=[],
                mag=[],
                magerr=[],
                instrument=[],
                stacked=[],
            )),
        )
        imhover.renderers.append(model_dict[key])

        # Plain data sources (no renderer) passed to the JS callbacks
        key = f'all{i}'
        model_dict[key] = ColumnDataSource(df)

        key = f'bold{i}'
        model_dict[key] = ColumnDataSource(df[[
            'mjd',
            'flux',
            'fluxerr',
            'mag',
            'magerr',
            'filter',
            'zp',
            'magsys',
            'lim_mag',
            'stacked',
        ]])

    plot.xaxis.axis_label = 'MJD'
    plot.yaxis.axis_label = 'AB mag'
    plot.toolbar.logo = None

    # Looked up once; reused below for the period plot
    obj = DBSession().query(Obj).get(obj_id)
    if obj.dm is not None:
        plot.extra_y_ranges = {
            "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
        }
        plot.add_layout(
            LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"),
            'right')

    now = Time.now().mjd
    plot.extra_x_ranges = {
        "Days Ago": Range1d(start=now - xmin, end=now - xmax)
    }
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"),
                    'below')

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline="tablet" in device,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={
                'toggle': toggle,
                **model_dict
            },
            code=_read_plotjs('togglem.js'),
        ))

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    # NOTE(review): obj_id is substituted into the JavaScript source without
    # escaping -- confirm IDs can never contain quote characters.
    button = Button(label="Export Bold Light Curve to CSV")
    button.js_on_click(
        CustomJS(
            args={
                'slider': slider,
                'toggle': toggle,
                **model_dict
            },
            code=_read_plotjs("download.js").replace(
                'objname', obj_id).replace('default_zp', str(PHOT_ZP)),
        ))

    # Don't need to expose CSV download on mobile
    top_layout = slider if small_screen else row(slider, button)

    callback = CustomJS(
        args={
            'slider': slider,
            'toggle': toggle,
            **model_dict
        },
        code=_read_plotjs('stackm.js').replace(
            'default_zp', str(PHOT_ZP)).replace(
                'detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)
    plot_layout = (column(plot, toggle) if small_screen else row(
        plot, toggle))
    layout = column(top_layout,
                    plot_layout,
                    sizing_mode='scale_width',
                    width=width)

    p2 = Panel(child=layout, title='Mag')

    # tabs for mag, flux (period is appended below when available)
    panels = [p2, p1]

    # now make period plot

    # get periods from annotations
    annotation_list = obj.get_annotations_readable_by(user)
    period_labels = []
    period_list = []
    for an in annotation_list:
        if 'period' in an.data:
            period_list.append(an.data['period'])
            period_labels.append(an.origin + ": %.9f" % an.data['period'])

    # the first annotated period is the initial folding period
    period = period_list[0] if period_list else None

    # don't generate if no period annotated
    if period is not None:

        # bokeh figure for period plotting
        period_plot = figure(
            aspect_ratio=1.5,
            sizing_mode='scale_both',
            active_drag='box_zoom',
            tools='box_zoom,wheel_zoom,pan,reset,save',
            y_range=(ymax, ymin),
            x_range=(-0.1, 1.1),  # initially one phase
            toolbar_location='above',
            toolbar_sticky=False,
            x_axis_location='below',
        )

        # axis labels
        period_plot.xaxis.axis_label = 'phase'
        period_plot.yaxis.axis_label = 'mag'
        period_plot.toolbar.logo = None

        # do we have a distance modulus (dm)?
        if obj.dm is not None:
            period_plot.extra_y_ranges = {
                "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
            }
            period_plot.add_layout(
                LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"),
                'right')

        # initiate hover tool
        period_imhover = HoverTool(tooltips=tooltip_format)
        period_imhover.renderers = []
        period_plot.add_tools(period_imhover)

        # initiate period radio buttons
        period_selection = RadioGroup(labels=period_labels, active=0)

        phase_selection = RadioGroup(labels=["One phase", "Two phases"],
                                     active=0)

        # store all the plot data
        period_model_dict = {}

        # iterate over each filter
        for i, (_, group) in enumerate(split):

            # Work on an explicit copy: adding columns to a frame yielded by
            # groupby is the chained-assignment pattern pandas warns about.
            df = group.copy()

            # fold x-axis on period in days
            df['mjd_folda'] = (df['mjd'] % period) / period
            df['mjd_foldb'] = df['mjd_folda'] + 1.0

            # only visible data
            detected = df[df['obs']]

            # phase plotting; the second phase ('b') starts hidden
            for ph in ['a', 'b']:
                fold_col = f'mjd_fold{ph}'
                key = f'fold{ph}{i}'
                period_model_dict[key] = period_plot.scatter(
                    x=fold_col,
                    y='mag',
                    color='color',
                    marker='circle',
                    fill_color='color',
                    alpha='alpha',
                    visible=(ph == 'a'),
                    source=ColumnDataSource(detected),
                )
                # add to hover tool
                period_imhover.renderers.append(period_model_dict[key])

                # errorbars for phases
                key = f'fold{ph}err{i}'
                y_err_x, y_err_y = _errorbar_segments(detected, fold_col,
                                                      'mag', 'magerr')
                # plot phase errors
                period_model_dict[key] = period_plot.multi_line(
                    xs='xs',
                    ys='ys',
                    color='color',
                    alpha='alpha',
                    visible=(ph == 'a'),
                    source=ColumnDataSource(data=dict(
                        xs=y_err_x,
                        ys=y_err_y,
                        color=detected['color'],
                        alpha=[1.0] * len(detected),
                    )),
                )

        # toggle for folded photometry
        period_toggle = CheckboxWithLegendGroup(
            labels=colors_labels.label.tolist(),
            active=list(range(len(colors_labels))),
            colors=colors_labels.color.tolist(),
            width=width // 5,
        )
        # use javascript to perform toggling on click
        # TODO replace `eval` with Namespaces
        # https://github.com/bokeh/bokeh/pull/6340
        period_toggle.js_on_click(
            CustomJS(
                args={
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=_read_plotjs('togglep.js'),
            ))

        # the same script serves the period text box and the phase selector
        foldphase_js = _read_plotjs('foldphase.js')

        # set up period adjustment text box
        period_title = Div(text="Period (days): ")
        period_textinput = TextInput(value=str(period))
        period_textinput.js_on_change(
            'value',
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=foldphase_js,
            ),
        )
        # a way to modify the period
        period_double_button = Button(label="*2")
        period_double_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(2.*period).toFixed(9);
                """,
            ))
        period_halve_button = Button(label="/2")
        period_halve_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                        const period = parseFloat(textinput.value);
                        textinput.value = parseFloat(period/2.).toFixed(9);
                        """,
            ))
        # a way to select the period
        period_selection.js_on_click(
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'periods': period_list
                },
                code="""
                textinput.value = parseFloat(periods[this.active]).toFixed(9);
                """,
            ))
        phase_selection.js_on_click(
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=foldphase_js,
            ))

        # layout

        period_column = column(
            period_toggle,
            period_title,
            period_textinput,
            period_selection,
            row(period_double_button, period_halve_button, width=180),
            phase_selection,
            width=180,
        )

        period_layout = column(
            row(period_plot, period_column),
            sizing_mode='scale_width',
            width=width,
        )

        # Period panel
        panels.append(Panel(child=period_layout, title='Period'))

    # tabs for mag, flux and (if present) period
    tabs = Tabs(tabs=panels,
                width=width,
                height=height,
                sizing_mode='fixed')
    return bokeh_embed.json_item(tabs)
Exemplo n.º 14
0
class ViewerWidgets(object):
    """ 
    Encapsulates Bokeh widgets, and related callbacks, that are part of prospect's GUI.
        Except for VI widgets
    """
    
    def __init__(self, plots, nspec):
        """Create the widgets every viewer layout shares.

        Args:
            plots: object exposing ``plot_width`` and ``plot_height``; both
                feed the common widget width.
            nspec: number of spectra; sets the upper end and the title of
                the spectrum slider.
        """
        self.js_files = get_resources('js')

        # Fixed button widths, plus a width derived from the plot geometry
        # (used for widgets scaling).
        self.navigation_button_width = self.z_button_width = 30
        self.plot_widget_width = (plots.plot_width + plots.plot_height // 2) // 2 - 40

        #-----
        #- Ifiberslider and smoothing widgets
        # Ifiberslider's value controls which spectrum is displayed
        # These two widgets call update_plot(), later defined
        if nspec > 1:
            last_index = nspec - 1
        else:
            last_index = 0.5  # Slider cannot have start=end
        self.ifiberslider = Slider(
            start=0,
            end=last_index,
            value=0,
            step=1,
            title='Spectrum (of {})'.format(nspec),
        )
        self.smootherslider = Slider(
            start=0,
            end=26,
            value=0,
            step=1.0,
            title='Gaussian Sigma Smooth',
        )

        # Optional widgets: remain None until add_coaddcam() /
        # add_model_select() create them.
        self.coaddcam_buttons = None
        self.model_select = None


    def add_navigation(self, nspec):
        """Create the "<" and ">" buttons that step through the spectra.

        Both buttons act on ``self.ifiberslider`` from JS: "<" decrements its
        value while it is positive, ">" increments it while it is below
        ``nspec-1``. Neither acts unless the slider's ``end`` is at least 1.

        Args:
            nspec: number of spectra, handed to the ">" handler.
        """
        button_width = self.navigation_button_width
        step_back_js = """
            if(ifiberslider.value>0 && ifiberslider.end>=1) {
                ifiberslider.value--
            }
            """
        step_forward_js = """
            if(ifiberslider.value<nspec-1 && ifiberslider.end>=1) {
                ifiberslider.value++
            }
            """

        #-----
        #- Navigation buttons
        self.prev_button = Button(label="<", width=button_width)
        self.next_button = Button(label=">", width=button_width)
        self.prev_callback = CustomJS(
            args={'ifiberslider': self.ifiberslider},
            code=step_back_js,
        )
        self.next_callback = CustomJS(
            args={'ifiberslider': self.ifiberslider, 'nspec': nspec},
            code=step_forward_js,
        )
        for button, handler in ((self.prev_button, self.prev_callback),
                                (self.next_button, self.next_callback)):
            button.js_on_event('button_click', handler)

    def add_resetrange(self, viewer_cds, plots):
        """Create the "Reset X-Y range" button and its JS callback.

        Args:
            viewer_cds: provides ``cds_spectra``, handed to the JS code.
            plots: provides ``fig``, ``xmin`` and ``xmax``, handed to the JS code.
        """
        #-----
        #- Axis reset button (supersedes the default bokeh "reset")
        self.reset_plotrange_button = Button(label="Reset X-Y range", button_type="default")

        # The callback source is the two JS resources glued together, in this order.
        js_parts = ("adapt_plotrange.js", "reset_plotrange.js")
        reset_plotrange_code = "".join(self.js_files[part] for part in js_parts)

        callback_args = {
            'fig': plots.fig,
            'xmin': plots.xmin,
            'xmax': plots.xmax,
            'spectra': viewer_cds.cds_spectra,
        }
        self.reset_plotrange_callback = CustomJS(args=callback_args, code=reset_plotrange_code)
        self.reset_plotrange_button.js_on_event('button_click', self.reset_plotrange_callback)


    def add_redshift_widgets(self, z, viewer_cds, plots):
        """Create the redshift and wavelength-frame widgets and their JS callbacks.

        The redshift is held as text in ``self.z_input``; ``self.zslider``
        (0.01 steps) and ``self.dzslider`` (0.0001 steps) hold its coarse and
        fine parts. The slider callbacks write into the text input and the
        text-input callback writes back into the sliders, so every callback
        only acts when the values actually disagree.

        Args:
            z: initial redshift shown by the widgets.
            viewer_cds: provides the ``cds_*`` data sources handed to the
                JS callbacks.
            plots: provides the figure, the spectral-line annotations and
                the overlap bands handed to the JS callbacks.
        """
        ## TODO handle "z" (same issue as viewerplots TBD)

        #-----
        #- Redshift / wavelength scale widgets
        # Split z into a coarse part (floored to 0.01) and the remainder.
        z1 = np.floor(z*100)/100
        dz = z-z1
        self.zslider = Slider(start=-0.1, end=5.0, value=z1, step=0.01, title='Redshift rough tuning')
        self.dzslider = Slider(start=0.0, end=0.0099, value=dz, step=0.0001, title='Redshift fine-tuning')
        self.dzslider.format = "0[.]0000"
        self.z_input = TextInput(value="{:.4f}".format(z), title="Redshift value:")

        #- Observer vs. Rest frame wavelengths
        self.waveframe_buttons = RadioButtonGroup(
            labels=["Obs", "Rest"], active=0)

        self.zslider_callback  = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            // Protect against 1) recursive call with z_input callback;
            //   2) out-of-range zslider values (should never happen in principle)
            var z1 = Math.floor(parseFloat(z_input.value)*100) / 100
            if ( (Math.abs(zslider.value-z1) >= 0.01) &&
                 (zslider.value >= -0.1) && (zslider.value <= 5.0) ){
                 var new_z = zslider.value + dzslider.value
                 z_input.value = new_z.toFixed(4)
                }
            """)

        # Same guard as the zslider callback, applied to the fine part of z.
        # The coarse part must be floor(z*100)/100, as in the zslider and
        # z_input callbacks; otherwise z2 is not the fine part and the
        # "only update when values disagree" test is ineffective.
        self.dzslider_callback  = CustomJS(
            args=dict(zslider=self.zslider, dzslider=self.dzslider, z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            var z1 = Math.floor(z*100) / 100
            var z2 = z-z1
            if ( (Math.abs(dzslider.value-z2) >= 0.0001) &&
                 (dzslider.value >= 0.0) && (dzslider.value <= 0.0099) ){
                 var new_z = zslider.value + dzslider.value
                 z_input.value = new_z.toFixed(4)
                }
            """)

        self.zslider.js_on_change('value', self.zslider_callback)
        self.dzslider.js_on_change('value', self.dzslider_callback)

        #- Buttons stepping the redshift by -/+ 0.01 through the text input
        self.z_minus_button = Button(label="<", width=self.z_button_width)
        self.z_plus_button = Button(label=">", width=self.z_button_width)
        self.z_minus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z >= -0.09) {
                z -= 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_plus_callback = CustomJS(
            args=dict(z_input=self.z_input),
            code="""
            var z = parseFloat(z_input.value)
            if(z <= 4.99) {
                z += 0.01
                z_input.value = z.toFixed(4)
            }
            """)
        self.z_minus_button.js_on_event('button_click', self.z_minus_callback)
        self.z_plus_button.js_on_event('button_click', self.z_plus_callback)

        #- Button restoring the 'Z' metadata value of the displayed spectrum
        self.zreset_button = Button(label='Reset to z_pipe')
        self.zreset_callback = CustomJS(
            args=dict(z_input=self.z_input, metadata=viewer_cds.cds_metadata, ifiberslider=self.ifiberslider),
            code="""
                var ifiber = ifiberslider.value
                var z = metadata.data['Z'][ifiber]
                z_input.value = z.toFixed(4)
            """)
        self.zreset_button.js_on_event('button_click', self.zreset_callback)

        #- Main redshift callback: normalises/clamps the text value, syncs the
        #  sliders, then moves spectral lines, overlap bands, spectra and
        #  models according to z and the selected wavelength frame.
        self.z_input_callback = CustomJS(
            args=dict(spectra = viewer_cds.cds_spectra,
                coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                model = viewer_cds.cds_model,
                othermodel = viewer_cds.cds_othermodel,
                metadata = viewer_cds.cds_metadata,
                ifiberslider = self.ifiberslider,
                zslider = self.zslider,
                dzslider = self.dzslider,
                z_input = self.z_input,
                waveframe_buttons = self.waveframe_buttons,
                line_data = viewer_cds.cds_spectral_lines,
                lines = plots.speclines,
                line_labels = plots.specline_labels,
                zlines = plots.zoom_speclines,
                zline_labels = plots.zoom_specline_labels,
                overlap_waves = plots.overlap_waves,
                overlap_bands = plots.overlap_bands,
                fig = plots.fig
                ),
            code="""
                var z = parseFloat(z_input.value)
                if ( z >=-0.1 && z <= 5.0 ) {
                    // update zsliders only if needed (avoid recursive call)
                    z_input.value = parseFloat(z_input.value).toFixed(4)
                    var z1 = Math.floor(z*100) / 100
                    var z2 = z-z1
                    if ( Math.abs(z1-zslider.value) >= 0.01) zslider.value = parseFloat(parseFloat(z1).toFixed(2))
                    if ( Math.abs(z2-dzslider.value) >= 0.0001) dzslider.value = parseFloat(parseFloat(z2).toFixed(4))
                } else {
                    if (z_input.value < -0.1) z_input.value = (-0.1).toFixed(4)
                    if (z_input.value > 5) z_input.value = (5.0).toFixed(4)
                }

                var line_restwave = line_data.data['restwave']
                var ifiber = ifiberslider.value
                var waveshift_lines = (waveframe_buttons.active == 0) ? 1+z : 1 ;
                var waveshift_spec = (waveframe_buttons.active == 0) ? 1 : 1/(1+z) ;

                for(var i=0; i<line_restwave.length; i++) {
                    lines[i].location = line_restwave[i] * waveshift_lines
                    line_labels[i].x = line_restwave[i] * waveshift_lines
                    zlines[i].location = line_restwave[i] * waveshift_lines
                    zline_labels[i].x = line_restwave[i] * waveshift_lines
                }
                if (overlap_bands.length>0) {
                    for (var i=0; i<overlap_bands.length; i++) {
                        overlap_bands[i].left = overlap_waves[i][0] * waveshift_spec
                        overlap_bands[i].right = overlap_waves[i][1] * waveshift_spec
                    }
                }

                function shift_plotwave(cds_spec, waveshift) {
                    var data = cds_spec.data
                    var origwave = data['origwave']
                    var plotwave = data['plotwave']
                    if ( plotwave[0] != origwave[0] * waveshift ) { // Avoid redo calculation if not needed
                        for (var j=0; j<plotwave.length; j++) {
                            plotwave[j] = origwave[j] * waveshift ;
                        }
                        cds_spec.change.emit()
                    }
                }

                for(var i=0; i<spectra.length; i++) {
                    shift_plotwave(spectra[i], waveshift_spec)
                }
                if (coaddcam_spec) shift_plotwave(coaddcam_spec, waveshift_spec)

                // Update model wavelength array
                // NEW : don't shift model if othermodel is there
                if (othermodel) {
                    var zref = othermodel.data['zref'][0]
                    var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zref) : 1/(1+zref) ;
                    shift_plotwave(othermodel, waveshift_model)
                } else if (model) {
                    var zfit = 0.0
                    if(metadata.data['Z'] !== undefined) {
                        zfit = metadata.data['Z'][ifiber]
                    }
                    var waveshift_model = (waveframe_buttons.active == 0) ? (1+z)/(1+zfit) : 1/(1+zfit) ;
                    shift_plotwave(model, waveshift_model)
                }
            """)
        self.z_input.js_on_change('value', self.z_input_callback)
        self.waveframe_buttons.js_on_click(self.z_input_callback)

        #- Rescale the figure's x-range when the wavelength frame is switched
        self.plotrange_callback = CustomJS(
            args = dict(
                z_input=self.z_input,
                waveframe_buttons=self.waveframe_buttons,
                fig=plots.fig,
            ),
            code="""
            var z =  parseFloat(z_input.value)
            // Observer Frame
            if(waveframe_buttons.active == 0) {
                fig.x_range.start = fig.x_range.start * (1+z)
                fig.x_range.end = fig.x_range.end * (1+z)
            } else {
                fig.x_range.start = fig.x_range.start / (1+z)
                fig.x_range.end = fig.x_range.end / (1+z)
            }
            """
        )
        self.waveframe_buttons.js_on_click(self.plotrange_callback) # TODO: for record: is this related to waveframe bug? : 2 callbacks for same click...


    def add_oii_widgets(self, plots):
        """Create the "OII-zoom" and "Undo OII-zoom" buttons.

        A small data source, shared by both JS callbacks, stores the x-range
        and smoothing value in effect before the zoom so that the undo
        button can restore them.

        Args:
            plots: provides ``fig``, whose x-range the callbacks read and set.
        """
        #------
        #- Zoom on the OII doublet TODO mv js code to other file
        # TODO: is there another trick than using a cds to pass the "oii_saveinfo" ?
        # TODO: optimize smoothing for autozoom (current value: 0)
        figure = plots.fig
        saveinfo_cds = ColumnDataSource(dict(
            xmin=[figure.x_range.start],
            xmax=[figure.x_range.end],
            nsmooth=[self.smootherslider.value],
        ))

        zoom_js = """
            // Save previous setting (for the "Undo" button)
            cds_oii_saveinfo.data['xmin'] = [fig.x_range.start]
            cds_oii_saveinfo.data['xmax'] = [fig.x_range.end]
            cds_oii_saveinfo.data['nsmooth'] = [smootherslider.value]
            // Center on the middle of the redshifted OII doublet (vaccum)
            var z = parseFloat(z_input.value)
            fig.x_range.start = 3728.48 * (1+z) - 100
            fig.x_range.end = 3728.48 * (1+z) + 100
            // No smoothing (this implies a call to update_plot)
            smootherslider.value = 0
            """
        undo_js = """
            fig.x_range.start = cds_oii_saveinfo.data['xmin'][0]
            fig.x_range.end = cds_oii_saveinfo.data['xmax'][0]
            smootherslider.value = cds_oii_saveinfo.data['nsmooth'][0]
            """

        # Zoom: remember the current view, then centre on the OII doublet.
        self.oii_zoom_button = Button(label="OII-zoom", button_type="default")
        self.oii_zoom_callback = CustomJS(
            args={
                'z_input': self.z_input,
                'fig': plots.fig,
                'smootherslider': self.smootherslider,
                'cds_oii_saveinfo': saveinfo_cds,
            },
            code=zoom_js,
        )
        self.oii_zoom_button.js_on_event('button_click', self.oii_zoom_callback)

        # Undo: restore what the zoom callback last saved.
        self.oii_undo_button = Button(label="Undo OII-zoom", button_type="default")
        self.oii_undo_callback = CustomJS(
            args={
                'fig': plots.fig,
                'smootherslider': self.smootherslider,
                'cds_oii_saveinfo': saveinfo_cds,
            },
            code=undo_js,
        )
        self.oii_undo_button.js_on_event('button_click', self.oii_undo_callback)


    def add_coaddcam(self, plots):
        """Create the radio buttons that highlight coadded or single-arm spectra.

        The JS callback sets the line alpha of every entry of the data/noise
        line lists (main and zoom) and the fill alpha of the overlap bands,
        depending on which button is active.

        Args:
            plots: provides the line lists, the overlap bands and the two
                alpha values handed to the JS code.
        """
        #-----
        #- Highlight individual-arm or camera-coadded spectra
        highlight_js = """
            var n_lines = list_lines[0].length
            for (var i=0; i<n_lines; i++) {
                var new_alpha = 1
                if (coaddcam_buttons.active == 0 && i<n_lines-1) new_alpha = alpha_discrete
                if (coaddcam_buttons.active == 1 && i==n_lines-1) new_alpha = alpha_discrete
                for (var j=0; j<list_lines.length; j++) {
                    list_lines[j][i].glyph.line_alpha = new_alpha
                }
            }
            var new_alpha = 0
            if (coaddcam_buttons.active == 0) new_alpha = alpha_overlapband
            for (var j=0; j<overlap_bands.length; j++) {
                    overlap_bands[j].fill_alpha = new_alpha
            }
            """

        self.coaddcam_buttons = RadioButtonGroup(
            labels=["Camera-coadded", "Single-arm"],
            active=0,
        )
        line_groups = [
            plots.data_lines,
            plots.noise_lines,
            plots.zoom_data_lines,
            plots.zoom_noise_lines,
        ]
        self.coaddcam_callback = CustomJS(
            args={
                'coaddcam_buttons': self.coaddcam_buttons,
                'list_lines': line_groups,
                'alpha_discrete': plots.alpha_discrete,
                'overlap_bands': plots.overlap_bands,
                'alpha_overlapband': plots.alpha_overlapband,
            },
            code=highlight_js,
        )
        self.coaddcam_buttons.js_on_click(self.coaddcam_callback)
    
    
    def add_metadata_tables(self, viewer_cds, show_zcat=True, template_dicts=None,
                           top_metadata=['TARGETID', 'EXPID']):
        """ Display object-related informations
                top_metadata: metadata to be highlighted in table_a
            
            Note: "short" CDS, with a single row, are used to fill these bokeh tables.
            When changing object, js code modifies these short CDS so that tables are updated.  
        """

        #- Sorted list of potential metadata:
        metadata_to_check = ['TARGETID', 'HPXPIXEL', 'TILEID', 'COADD_NUMEXP', 'COADD_EXPTIME', 
                             'NIGHT', 'EXPID', 'FIBER', 'CAMERA', 'MORPHTYPE']
        metadata_to_check += [ ('mag_'+x) for x in viewer_cds.phot_bands ]
        table_keys = []
        for key in metadata_to_check:
            if key in viewer_cds.cds_metadata.data.keys():
                table_keys.append(key)
            if 'NUM_'+key in viewer_cds.cds_metadata.data.keys():
                for prefix in ['FIRST','LAST','NUM']:
                    table_keys.append(prefix+'_'+key)
                    if key in top_metadata:
                        top_metadata.append(prefix+'_'+key)
        
        #- Table a: "top metadata"
        table_a_keys = [ x for x in table_keys if x in top_metadata ]
        self.shortcds_table_a, self.table_a = _metadata_table(table_a_keys, viewer_cds, table_width=600, 
                                                              shortcds_name='shortcds_table_a', selectable=True)
        #- Table b: Targeting information
        self.shortcds_table_b, self.table_b = _metadata_table(['Targeting masks'], viewer_cds, table_width=self.plot_widget_width,
                                                              shortcds_name='shortcds_table_b', selectable=True)
        #- Table(s) c/d : Other information (imaging, etc.)
        remaining_keys = [ x for x in table_keys if x not in top_metadata ]
        if len(remaining_keys) > 7:
            table_c_keys = remaining_keys[0:len(remaining_keys)//2]
            table_d_keys = remaining_keys[len(remaining_keys)//2:]
        else:
            table_c_keys = remaining_keys
            table_d_keys = None
        self.shortcds_table_c, self.table_c = _metadata_table(table_c_keys, viewer_cds, table_width=self.plot_widget_width,
                                                             shortcds_name='shortcds_table_c', selectable=False)
        if table_d_keys is None:
            self.shortcds_table_d, self.table_d = None, None
        else:
            self.shortcds_table_d, self.table_d = _metadata_table(table_d_keys, viewer_cds, table_width=self.plot_widget_width,
                                                                 shortcds_name='shortcds_table_d', selectable=False)

        #- Table z: redshift fitting information
        if show_zcat is not None :
            if template_dicts is not None : # Add other best fits
                fit_results = template_dicts[1]
                # Case of DeltaChi2 : compute it from Chi2s
                #    The "DeltaChi2" in rr fits is between best fits for a given (spectype,subtype)
                #    Convention: DeltaChi2 = -1 for the last fit.
                chi2s = fit_results['CHI2'][0]
                full_deltachi2s = np.zeros(len(chi2s))-1
                full_deltachi2s[:-1] = chi2s[1:]-chi2s[:-1]
                cdsdata = dict(Nfit = np.arange(1,len(chi2s)+1),
                                SPECTYPE = fit_results['SPECTYPE'][0],  # [0:num_best_fits] (if we want to restrict... TODO?)
                                SUBTYPE = fit_results['SUBTYPE'][0],
                                Z = [ "{:.4f}".format(x) for x in fit_results['Z'][0] ],
                                ZERR = [ "{:.4f}".format(x) for x in fit_results['ZERR'][0] ],
                                ZWARN = fit_results['ZWARN'][0],
                                CHI2 = [ "{:.1f}".format(x) for x in fit_results['CHI2'][0] ],
                                DELTACHI2 = [ "{:.1f}".format(x) for x in full_deltachi2s ])
                self.shortcds_table_z = ColumnDataSource(cdsdata, name='shortcds_table_z')
                columns_table_z = [ TableColumn(field=x, title=t, width=w) for x,t,w in [ ('Nfit','Nfit',5), ('SPECTYPE','SPECTYPE',70), ('SUBTYPE','SUBTYPE',60), ('Z','Z',50) , ('ZERR','ZERR',50), ('ZWARN','ZWARN',50), ('DELTACHI2','Δχ2(N+1/N)',70)] ]
                self.table_z = DataTable(source=self.shortcds_table_z, columns=columns_table_z,
                                         selectable=False, index_position=None, width=self.plot_widget_width)
                self.table_z.height = 3 * self.table_z.row_height
            else :
                self.shortcds_table_z, self.table_z = _metadata_table(viewer_cds.zcat_keys, viewer_cds,
                                    table_width=self.plot_widget_width, shortcds_name='shortcds_table_z', selectable=False)
        else :
            self.table_z = Div(text="Not available ")
            self.shortcds_table_z = None


    def add_specline_toggles(self, viewer_cds, plots):
        """Create the toggles controlling which spectral lines are visible.

        One JS callback serves both widgets: it reads the emission/absorption
        buttons and the "major lines only" checkbox, then sets ``visible`` on
        every line and label (main and zoom).

        Args:
            viewer_cds: provides ``cds_spectral_lines``, handed to the JS code.
            plots: provides the line and label annotations handed to the JS code.
        """
        visibility_js = """
            var show_emission = false
            var show_absorption = false
            if (lines_button_group.active.indexOf(0) >= 0) {  // index 0=Emission in active list
                show_emission = true
            }
            if (lines_button_group.active.indexOf(1) >= 0) {  // index 1=Absorption in active list
                show_absorption = true
            }

            for(var i=0; i<lines.length; i++) {
                if ( !(line_data.data['major'][i]) && (majorline_checkbox.active.indexOf(0)>=0) ) {
                    lines[i].visible = false
                    line_labels[i].visible = false
                    zlines[i].visible = false
                    zline_labels[i].visible = false
                } else if (line_data.data['emission'][i]) {
                    lines[i].visible = show_emission
                    line_labels[i].visible = show_emission
                    zlines[i].visible = show_emission
                    zline_labels[i].visible = show_emission
                } else {
                    lines[i].visible = show_absorption
                    line_labels[i].visible = show_absorption
                    zlines[i].visible = show_absorption
                    zline_labels[i].visible = show_absorption
                }
            }
            """

        #-----
        #- Toggle lines
        self.speclines_button_group = CheckboxButtonGroup(
            labels=["Emission lines", "Absorption lines"],
            active=[],
        )
        self.majorline_checkbox = CheckboxGroup(
            labels=['Show only major lines'],
            active=[],
        )

        self.speclines_callback = CustomJS(
            args={
                'line_data': viewer_cds.cds_spectral_lines,
                'lines': plots.speclines,
                'line_labels': plots.specline_labels,
                'zlines': plots.zoom_speclines,
                'zline_labels': plots.zoom_specline_labels,
                'lines_button_group': self.speclines_button_group,
                'majorline_checkbox': self.majorline_checkbox,
            },
            code=visibility_js,
        )
        # Either widget triggers the same visibility refresh.
        for toggle in (self.speclines_button_group, self.majorline_checkbox):
            toggle.js_on_click(self.speclines_callback)


    def add_model_select(self, viewer_cds, template_dicts, num_approx_fits, with_full_2ndfit=True):
        """Create the drop-down selecting the secondary model to display.

        Reads ``self.z_input``, which add_redshift_widgets() creates.

        Args:
            viewer_cds: provides the model, metadata and median-spectra
                sources handed to the JS code.
            template_dicts: indexable; items 0, 1 and 2 are handed to the JS
                code as ``fit_templates``, ``fit_results`` and ``std_templates``.
            num_approx_fits: number of "Nth fit (approx)" options to offer.
            with_full_2ndfit: when exactly ``False``, the '2nd best fit'
                option is left out.
        """
        #------
        #- Select secondary model to display
        # NOTE(review): only 1, 2 and 3 get a special suffix, so 21 yields
        # "21th"; kept as-is because these labels are runtime strings.
        suffixes = {1: 'st', 2: 'nd', 3: 'rd'}

        model_options = ['Best fit']
        if with_full_2ndfit is not False:
            model_options.append('2nd best fit')
        model_options.extend(
            '{}{} fit (approx)'.format(rank, suffixes.get(rank, 'th'))
            for rank in range(1, 1 + num_approx_fits)
        )
        model_options.extend('STD ' + name for name in ('QSO', 'GALAXY', 'STAR'))

        self.model_select = Select(
            value=model_options[0],
            title="Other model (dashed curve):",
            options=model_options,
        )

        # The callback source is these JS resources glued together, in this order.
        js_parts = ("interp_grid.js", "smooth_data.js", "select_model.js")
        model_select_code = "".join(self.js_files[part] for part in js_parts)

        self.model_select_callback = CustomJS(
            args={
                'ifiberslider': self.ifiberslider,
                'model_select': self.model_select,
                'fit_templates': template_dicts[0],
                'cds_othermodel': viewer_cds.cds_othermodel,
                'cds_model_2ndfit': viewer_cds.cds_model_2ndfit,
                'cds_model': viewer_cds.cds_model,
                'fit_results': template_dicts[1],
                'std_templates': template_dicts[2],
                'median_spectra': viewer_cds.cds_median_spectra,
                'smootherslider': self.smootherslider,
                'z_input': self.z_input,
                'cds_metadata': viewer_cds.cds_metadata,
            },
            code=model_select_code,
        )
        self.model_select.js_on_change('value', self.model_select_callback)


    def add_update_plot_callback(self, viewer_cds, plots, vi_widgets, template_dicts):
        """Build the main JS callback that updates the plots, and register it.

        The callback source is the concatenation of five JS resources from
        ``self.js_files``. It is attached to the ``value`` of both
        ``self.smootherslider`` and ``self.ifiberslider``.

        Reads ``self.shortcds_table_a/b/c/d/z`` (set by add_metadata_tables),
        ``self.z_input`` (set by add_redshift_widgets) and ``self.model_select``
        (None unless add_model_select has run).

        Args:
            viewer_cds: provides the ``cds_*`` data sources handed to the JS code.
            plots: provides ``fig``, ``imfig_source`` and ``imfig_urls``.
            vi_widgets: provides the ``vi_*`` widgets handed to the JS code.
            template_dicts: None, or an indexable whose item 1 is handed to
                the JS code as ``fit_results``.
        """
        #-----
        #- Main js code to update plots
        # Concatenation order is kept as listed here.
        update_plot_code = (self.js_files["adapt_plotrange.js"] + self.js_files["interp_grid.js"] +
                            self.js_files["smooth_data.js"] + self.js_files["coadd_brz_cameras.js"] +
                            self.js_files["update_plot.js"])
        # TMP handling of template_dicts
        # Without templates, the JS side receives fit_results = None.
        the_fit_results = None if template_dicts is None else template_dicts[1] # dirty
        self.update_plot_callback = CustomJS(
            args = dict(
                spectra = viewer_cds.cds_spectra,
                coaddcam_spec = viewer_cds.cds_coaddcam_spec,
                model = viewer_cds.cds_model,
                othermodel = viewer_cds.cds_othermodel,
                model_2ndfit = viewer_cds.cds_model_2ndfit,
                metadata = viewer_cds.cds_metadata,
                fit_results = the_fit_results,
                shortcds_table_z = self.shortcds_table_z,
                shortcds_table_a = self.shortcds_table_a,
                shortcds_table_b = self.shortcds_table_b,
                shortcds_table_c = self.shortcds_table_c,
                shortcds_table_d = self.shortcds_table_d,
                ifiberslider = self.ifiberslider,
                smootherslider = self.smootherslider,
                z_input = self.z_input,
                fig = plots.fig,
                imfig_source = plots.imfig_source,
                imfig_urls = plots.imfig_urls,
                model_select = self.model_select,
                vi_comment_input = vi_widgets.vi_comment_input,
                vi_std_comment_select = vi_widgets.vi_std_comment_select,
                vi_name_input = vi_widgets.vi_name_input,
                vi_quality_input = vi_widgets.vi_quality_input,
                vi_quality_labels = vi_widgets.vi_quality_labels,
                vi_issue_input = vi_widgets.vi_issue_input,
                vi_z_input = vi_widgets.vi_z_input,
                vi_category_select = vi_widgets.vi_category_select,
                vi_issue_slabels = vi_widgets.vi_issue_slabels
                ),
            code = update_plot_code
        )
        # Changing either the smoothing or the displayed spectrum re-runs the update.
        self.smootherslider.js_on_change('value', self.update_plot_callback)
        self.ifiberslider.js_on_change('value', self.update_plot_callback)