コード例 #1
0
ファイル: plot.py プロジェクト: Yashvi-Sharma/skyportal
def _read_plot_js(filename):
    """Return the text of a JavaScript callback file from the plotjs directory.

    The path is resolved relative to this module
    (``../static/js/plotjs/<filename>``).  The file is read inside a ``with``
    block so the handle is closed as soon as the text has been loaded.
    """
    path = os.path.join(os.path.dirname(__file__), '../static/js/plotjs', filename)
    with open(path) as js_file:
        return js_file.read()


def _empty_phot_source():
    """Return a new ColumnDataSource whose photometry columns are all empty.

    Used as the data source of the "bin"/"unobsbin" renderers, which start
    out without any points and are passed to the binning JS callbacks.
    """
    return ColumnDataSource(
        data=dict(
            mjd=[],
            flux=[],
            fluxerr=[],
            filter=[],
            color=[],
            lim_mag=[],
            mag=[],
            magerr=[],
            stacked=[],
            instrument=[],
        )
    )


def photometry_plot(obj_id, user, width=600, device="browser"):
    """Create object photometry scatter plot.

    Builds a tabbed Bokeh layout with a magnitude light curve, a flux light
    curve and, when at least one readable annotation carries a ``period``
    value, a phase-folded light curve.

    Parameters
    ----------
    obj_id : str
        ID of Obj to be plotted.
    user
        The requesting user.  Its ``accessible_groups`` restrict which
        photometry rows are loaded, and it is passed on to the spectrum and
        annotation access queries.
    width : int, optional
        Overall width of the layout; the plot frame width is derived from it.
    device : str, optional
        Layout profile.  One of ``"browser"``, ``"mobile_portrait"``,
        ``"mobile_landscape"``, ``"tablet_portrait"`` or
        ``"tablet_landscape"``.

    Returns
    -------
    dict
        Returns Bokeh JSON embedding for the desired plot.  If no photometry
        is available for this object/user, the tuple ``(None, None, None)``
        is returned instead (kept as-is for existing callers).

    Raises
    ------
    ValueError
        If photometry exists but ``device`` is not one of the supported
        layout profiles.
    """

    data = pd.read_sql(
        DBSession()
        .query(
            Photometry,
            Telescope.nickname.label("telescope"),
            Instrument.name.label("instrument"),
        )
        .join(Instrument, Instrument.id == Photometry.instrument_id)
        .join(Telescope, Telescope.id == Instrument.telescope_id)
        .filter(Photometry.obj_id == obj_id)
        .filter(
            Photometry.groups.any(Group.id.in_([g.id for g in user.accessible_groups]))
        )
        .statement,
        DBSession().bind,
    )

    # Nothing to plot: callers receive a 3-tuple of None in this case.
    if data.empty:
        return None, None, None

    # get spectra to annotate on phot plots
    spectra = (
        Spectrum.query_records_accessible_by(user)
        .filter(Spectrum.obj_id == obj_id)
        .all()
    )

    data['color'] = [get_color(f) for f in data['filter']]

    # get marker for each unique instrument (cycling through phot_markers)
    instruments = list(data.instrument.unique())
    markers = [phot_markers[i % len(phot_markers)] for i in range(len(instruments))]

    filters = list(set(data['filter']))
    colors = [get_color(f) for f in filters]

    color_mapper = CategoricalColorMapper(factors=filters, palette=colors)
    color_dict = {'field': 'filter', 'transform': color_mapper}

    # One legend label per row: "<instrument>/<filter>[/<origin>]"
    labels = []
    for instrument, filt, origin in zip(
        data['instrument'], data['filter'], data['origin']
    ):
        label = f'{instrument}/{filt}'
        if origin is not None:
            label += f'/{origin}'
        labels.append(label)

    data['label'] = labels
    data['zp'] = PHOT_ZP
    data['magsys'] = 'ab'
    data['alpha'] = 1.0
    data['lim_mag'] = (
        -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) + data['zp']
    )

    # Passing a dictionary to a bokeh datasource causes the frontend to die,
    # deleting the dictionary column fixes that
    del data['original_user_data']

    # keep track of things that are only upper limits
    data['hasflux'] = ~data['flux'].isna()

    # calculate the magnitudes - a photometry point is considered "significant"
    # or "detected" (and thus can be represented by a magnitude) if its snr
    # is above PHOT_DETECTION_THRESHOLD
    obsind = data['hasflux'] & (
        data['flux'].fillna(0.0) / data['fluxerr'] >= PHOT_DETECTION_THRESHOLD
    )
    data.loc[~obsind, 'mag'] = None
    data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP

    # calculate the magnitude errors using standard error propagation formulae
    # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae
    data.loc[~obsind, 'magerr'] = None
    coeff = 2.5 / np.log(10)
    magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux'])
    data.loc[obsind, 'magerr'] = magerrs
    data['obs'] = obsind
    data['stacked'] = False

    # One group (and one legend entry) per label, in order of first appearance
    split = data.groupby('label', sort=False)
    n_labels = len(split)

    # Flux axis limits: 5% padding around the finite flux values
    finite = np.isfinite(data['flux'])
    fdata = data[finite]
    lower = np.min(fdata['flux']) * 0.95
    upper = np.max(fdata['flux']) * 1.05

    xmin = data['mjd'].min() - 2
    xmax = data['mjd'].max() + 2

    # Layout parameters based on device type
    is_handheld = "mobile" in device or "tablet" in device
    active_drag = None if is_handheld else "box_zoom"
    tools = 'box_zoom,pan,reset' if is_handheld else "box_zoom,wheel_zoom,pan,reset,save"
    legend_loc = "below" if is_handheld else "right"
    legend_orientation = (
        "vertical" if device in ["browser", "mobile_portrait"] else "horizontal"
    )

    # Compute a plot component height based on rough number of legend rows added below the plot
    # Values are based on default sizing of bokeh components and an estimate of how many
    # legend items would fit on the average device screen. Note that the legend items per
    # row is computed more exactly later once labels are extracted from the data (with the
    # add_plot_legend() function).
    #
    # The height is manually computed like this instead of using built in aspect_ratio/sizing options
    # because with the new Interactive Legend approach (instead of the legacy CheckboxLegendGroup), the
    # Legend component is considered part of the plot and plays into the sizing computations. Since the
    # number of items in the legend can alter the needed heights of the plot, using built-in Bokeh options
    # for sizing does not allow for keeping the actual graph part of the plot at a consistent aspect ratio.
    #
    # For the frame width, by default we take the desired plot width minus 64 for the y-axis/label taking
    # up horizontal space
    frame_width = width - 64
    if device == "mobile_portrait":
        legend_items_per_row = 1
        legend_row_height = 24
        aspect_ratio = 1
    elif device == "mobile_landscape":
        legend_items_per_row = 4
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "tablet_portrait":
        legend_items_per_row = 5
        legend_row_height = 50
        aspect_ratio = 1.5
    elif device == "tablet_landscape":
        legend_items_per_row = 7
        legend_row_height = 50
        aspect_ratio = 1.8
    elif device == "browser":
        # Width minus some base width for the legend, which is only a column to the right
        # for browser mode
        frame_width = width - 200
    else:
        # Without this branch the legend sizing variables would be unbound and
        # the height computation below would fail with an opaque NameError.
        raise ValueError(f'Unsupported device type: {device!r}')

    height = (
        500
        if device == "browser"
        else math.floor(width / aspect_ratio)
        + legend_row_height * int(n_labels / legend_items_per_row)
        + 30  # 30 is the height of the toolbar
    )

    # ------------------------------------------------------------------
    # Flux light curve
    # ------------------------------------------------------------------
    plot = figure(
        frame_width=frame_width,
        height=height,
        active_drag=active_drag,
        tools=tools,
        toolbar_location='above',
        toolbar_sticky=True,
        y_range=(lower, upper),
        min_border_right=16,
        x_axis_location='above',
        sizing_mode="stretch_width",
    )

    plot.xaxis.axis_label = 'MJD'
    now = Time.now().mjd
    plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)}
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below')

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    # Renderers keyed by name; passed as arguments to the JS binning callback
    model_dict = {}
    legend_items = []
    for i, (label, sdf) in enumerate(split):
        renderers = []

        # for the flux plot, we only show things that have a flux value
        df = sdf[sdf['hasflux']]

        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            alpha='alpha',
            source=ColumnDataSource(df),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        # Initially-empty renderer for the binned points
        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            source=_empty_phot_source(),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        # Vertical error bars: one two-point segment per observation
        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for _, ro in df.iterrows():
            px = ro['mjd']
            py = ro['flux']
            err = ro['fluxerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x, ys=y_err_y, color=df['color'], alpha=[1.0] * len(df)
                )
            ),
        )
        renderers.append(model_dict[key])

        # Initially-empty renderer for the binned error bars
        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )
        renderers.append(model_dict[key])

        legend_items.append(LegendItem(label=label, renderers=renderers))

    if device == "mobile_portrait":
        plot.xaxis.ticker.desired_num_ticks = 5
    plot.yaxis.axis_label = 'Flux (μJy)'
    plot.toolbar.logo = None

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)
    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    # The JS template contains placeholder tokens that are substituted here
    callback = CustomJS(
        args={'slider': slider, 'n_labels': n_labels, **model_dict},
        code=_read_plot_js('stackf.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )

    slider.js_on_change('value', callback)

    # Colors of the first/last detection marker lines.  Defined outside the
    # conditionals because both the flux and the mag plot use them.
    first_color = "#34b4eb"
    last_color = "#8992f5"

    # Mark the first and last detections
    detection_dates = data[data['hasflux']]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        midpoint = (upper + lower) / 2
        # Extend the marker lines well beyond the initial y-range
        line_top = 5 * upper - 4 * midpoint
        line_bottom = 5 * lower - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, lower, upper)
    layout = column(slider, plot, width=width, height=height)

    p1 = Panel(child=layout, title='Flux')

    # ------------------------------------------------------------------
    # Magnitude light curve
    # ------------------------------------------------------------------
    # y-limits cover both detections (mag) and non-detections (lim_mag)
    ymax = (
        np.nanmax(
            (
                np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmax(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        + 0.1
    )
    ymin = (
        np.nanmin(
            (
                np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmin(data.loc[~obsind, 'lim_mag']) if any(~obsind) else np.nan,
            )
        )
        - 0.1
    )

    # y_range is given as (ymax, ymin) so the magnitude axis is inverted
    plot = figure(
        frame_width=frame_width,
        height=height,
        active_drag=active_drag,
        tools=tools,
        y_range=(ymax, ymin),
        x_range=(xmin, xmax),
        toolbar_location='above',
        toolbar_sticky=True,
        x_axis_location='above',
        sizing_mode="stretch_width",
    )

    plot.xaxis.axis_label = 'MJD'
    now = Time.now().mjd
    plot.extra_x_ranges = {"Days Ago": Range1d(start=now - xmin, end=now - xmax)}
    plot.add_layout(LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below')

    # Secondary y-axis (m - DM) when the object has a distance modulus
    obj = DBSession().query(Obj).get(obj_id)
    if obj.dm is not None:
        plot.extra_y_ranges = {
            "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
        }
        plot.add_layout(
            LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
        )

    # Mark the first and last detections again
    detection_dates = data[obsind]['mjd']
    if len(detection_dates) > 0:
        first = round(detection_dates.min(), 6)
        last = round(detection_dates.max(), 6)
        midpoint = (ymax + ymin) / 2
        line_top = 5 * ymax - 4 * midpoint
        line_bottom = 5 * ymin - 4 * midpoint
        y = np.linspace(line_bottom, line_top, num=5000)
        first_r = plot.line(
            x=np.full(5000, first),
            y=y,
            line_alpha=0.5,
            line_color=first_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("First detection", f'{first}')],
                renderers=[first_r],
            )
        )
        last_r = plot.line(
            x=np.full(5000, last),
            y=y,
            line_alpha=0.5,
            line_color=last_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[("Last detection", f'{last}')],
                renderers=[last_r],
                point_policy='follow_mouse',
            )
        )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, ymax, ymin)

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}

    # Legend items are individually stored instead of being applied
    # directly when plotting so that they can be separated into multiple
    # Legend() components if needed (to simulate horizontal row wrapping).
    # This is necessary because Bokeh does not support row wrapping with
    # horizontally-oriented legends out-of-the-box.
    legend_items = []
    for i, (label, df) in enumerate(split):
        renderers = []
        detected = df[df['obs']]

        # Detections
        key = f'obs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker=factor_mark('instrument', markers, instruments),
            fill_color=color_dict,
            alpha='alpha',
            source=ColumnDataSource(detected),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        # Non-detections, drawn as limiting magnitudes
        unobs_source = df[~df['obs']].copy()
        unobs_source.loc[:, 'alpha'] = 0.8

        key = f'unobs{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color=color_dict,
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha='alpha',
            source=ColumnDataSource(unobs_source),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        # Initially-empty renderer for the binned detections
        key = f'bin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='mag',
            color=color_dict,
            marker=factor_mark('instrument', markers, instruments),
            fill_color='color',
            source=_empty_phot_source(),
        )
        renderers.append(model_dict[key])
        imhover.renderers.append(model_dict[key])

        # Vertical error bars for the detections
        key = 'obserr' + str(i)
        y_err_x = []
        y_err_y = []

        for _, ro in detected.iterrows():
            px = ro['mjd']
            py = ro['mag']
            err = ro['magerr']

            y_err_x.append((px, px))
            y_err_y.append((py - err, py + err))

        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=y_err_x,
                    ys=y_err_y,
                    color=detected['color'],
                    alpha=[1.0] * len(detected),
                )
            ),
        )
        renderers.append(model_dict[key])

        # Initially-empty renderer for the binned error bars
        key = f'binerr{i}'
        model_dict[key] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )
        renderers.append(model_dict[key])

        # Initially-empty renderer for the binned non-detections
        key = f'unobsbin{i}'
        model_dict[key] = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color=color_dict,
            alpha=0.8,
            source=_empty_phot_source(),
        )
        imhover.renderers.append(model_dict[key])
        renderers.append(model_dict[key])

        # Plain data sources (no renderer) handed to the JS callbacks
        key = f'all{i}'
        model_dict[key] = ColumnDataSource(df)

        key = f'bold{i}'
        model_dict[key] = ColumnDataSource(
            df[
                [
                    'mjd',
                    'flux',
                    'fluxerr',
                    'mag',
                    'magerr',
                    'filter',
                    'zp',
                    'magsys',
                    'lim_mag',
                    'stacked',
                ]
            ]
        )

        legend_items.append(LegendItem(label=label, renderers=renderers))

    add_plot_legend(plot, legend_items, width, legend_orientation, legend_loc)

    plot.yaxis.axis_label = 'AB mag'
    plot.toolbar.logo = None

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    button = Button(label="Export Bold Light Curve to CSV")
    button.js_on_click(
        CustomJS(
            args={'slider': slider, 'n_labels': n_labels, **model_dict},
            code=_read_plot_js("download.js")
            .replace('objname', obj_id)
            .replace('default_zp', str(PHOT_ZP)),
        )
    )

    # Don't need to expose CSV download on mobile
    top_layout = slider if is_handheld else row(slider, button)

    callback = CustomJS(
        args={'slider': slider, 'n_labels': n_labels, **model_dict},
        code=_read_plot_js('stackm.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)
    layout = column(top_layout, plot, width=width, height=height)

    p2 = Panel(child=layout, title='Mag')

    # ------------------------------------------------------------------
    # Period (phase-folded) plot
    # ------------------------------------------------------------------

    # get periods from annotations
    annotation_list = obj.get_annotations_readable_by(user)
    period_labels = []
    period_list = []
    for an in annotation_list:
        if 'period' in an.data:
            period_list.append(an.data['period'])
            period_labels.append(an.origin + ": %.9f" % an.data['period'])

    # The first annotated period is the one initially used for folding
    period = period_list[0] if period_list else None

    # don't generate if no period annotated
    if period is not None:
        # bokeh figure for period plotting
        period_plot = figure(
            frame_width=frame_width,
            height=height,
            active_drag=active_drag,
            tools=tools,
            y_range=(ymax, ymin),
            x_range=(-0.01, 2.01),  # initially one phase
            toolbar_location='above',
            toolbar_sticky=False,
            x_axis_location='below',
            sizing_mode="stretch_width",
        )

        # axis labels
        period_plot.xaxis.axis_label = 'phase'
        period_plot.yaxis.axis_label = 'mag'
        period_plot.toolbar.logo = None

        # do we have a distance modulus (dm)?  (obj was already loaded above)
        if obj.dm is not None:
            period_plot.extra_y_ranges = {
                "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
            }
            period_plot.add_layout(
                LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"), 'right'
            )

        # initiate hover tool
        period_imhover = HoverTool(tooltips=tooltip_format)
        period_imhover.renderers = []
        period_plot.add_tools(period_imhover)

        # initiate period radio buttons
        period_selection = RadioGroup(labels=period_labels, active=0)

        phase_selection = RadioGroup(labels=["One phase", "Two phases"], active=1)

        # store all the plot data
        period_model_dict = {}

        # iterate over each label group
        legend_items = []
        for i, (label, df) in enumerate(split):
            renderers = []
            # Work on an explicit copy so the folded-phase columns are added
            # to a frame we own rather than to an object handed out by groupby
            df = df.copy()
            # fold x-axis on period in days
            df['mjd_folda'] = (df['mjd'] % period) / period
            df['mjd_foldb'] = df['mjd_folda'] + 1.0
            detected = df[df['obs']]  # only visible data

            # phase plotting: 'a' is phase [0, 1), 'b' the same points shifted by 1
            for ph in ['a', 'b']:
                key = 'fold' + ph + f'{i}'
                period_model_dict[key] = period_plot.scatter(
                    x='mjd_fold' + ph,
                    y='mag',
                    color='color',
                    marker=factor_mark('instrument', markers, instruments),
                    fill_color=color_dict,
                    alpha='alpha',
                    source=ColumnDataSource(detected),
                )
                # add to hover tool
                period_imhover.renderers.append(period_model_dict[key])
                renderers.append(period_model_dict[key])

                # errorbars for phases
                key = 'fold' + ph + f'err{i}'
                y_err_x = []
                y_err_y = []

                # get each visible error value
                for _, ro in detected.iterrows():
                    px = ro['mjd_fold' + ph]
                    py = ro['mag']
                    err = ro['magerr']
                    # set up error tuples
                    y_err_x.append((px, px))
                    y_err_y.append((py - err, py + err))
                # plot phase errors
                period_model_dict[key] = period_plot.multi_line(
                    xs='xs',
                    ys='ys',
                    color='color',
                    alpha='alpha',
                    source=ColumnDataSource(
                        data=dict(
                            xs=y_err_x,
                            ys=y_err_y,
                            color=detected['color'],
                            alpha=[1.0] * len(detected),
                        )
                    ),
                )
                renderers.append(period_model_dict[key])

            legend_items.append(LegendItem(label=label, renderers=renderers))

        add_plot_legend(
            period_plot, legend_items, width, legend_orientation, legend_loc
        )

        # set up period adjustment text box
        period_title = Div(text="Period (days): ")
        period_textinput = TextInput(value=str(period))

        # The same folding script and arguments serve both the text box and
        # the phase selector, so the file is read only once.
        fold_js = _read_plot_js('foldphase.js')
        fold_args = {
            'textinput': period_textinput,
            'numphases': phase_selection,
            'n_labels': n_labels,
            'p': period_plot,
            **period_model_dict,
        }
        period_textinput.js_on_change(
            'value', CustomJS(args=dict(fold_args), code=fold_js)
        )
        # a way to modify the period
        period_double_button = Button(label="*2", width=30)
        period_double_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(2.*period).toFixed(9);
                """,
            )
        )
        period_halve_button = Button(label="/2", width=30)
        period_halve_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                        const period = parseFloat(textinput.value);
                        textinput.value = parseFloat(period/2.).toFixed(9);
                        """,
            )
        )
        # a way to select the period
        period_selection.js_on_click(
            CustomJS(
                args={'textinput': period_textinput, 'periods': period_list},
                code="""
                textinput.value = parseFloat(periods[this.active]).toFixed(9);
                """,
            )
        )
        phase_selection.js_on_click(CustomJS(args=dict(fold_args), code=fold_js))

        # layout
        if device == "mobile_portrait":
            period_controls = column(
                row(
                    period_title,
                    period_textinput,
                    period_double_button,
                    period_halve_button,
                    width=width,
                    sizing_mode="scale_width",
                ),
                phase_selection,
                period_selection,
                width=width,
            )
            # Add extra height to plot based on period control components added
            # 18 is the height of each period selection radio option (per default font size)
            # and the 130 encompasses the other components which are consistent no matter
            # the data size.
            height += 130 + 18 * len(period_labels)
        else:
            period_controls = column(
                row(
                    period_title,
                    period_textinput,
                    period_double_button,
                    period_halve_button,
                    phase_selection,
                    width=width,
                    sizing_mode="scale_width",
                ),
                period_selection,
                margin=10,
            )
            # Add extra height to plot based on period control components added
            # Numbers are derived in similar manner to the "mobile_portrait" case above
            height += 90 + 18 * len(period_labels)

        period_layout = column(period_plot, period_controls, width=width, height=height)

        # Period panel
        p3 = Panel(child=period_layout, title='Period')

        # tabs for mag, flux, period
        tabs = Tabs(tabs=[p2, p1, p3], width=width, height=height, sizing_mode='fixed')
    else:
        # tabs for mag, flux
        tabs = Tabs(tabs=[p2, p1], width=width, height=height + 90, sizing_mode='fixed')
    return bokeh_embed.json_item(tabs)
コード例 #2
0
ファイル: plot.py プロジェクト: alucab/skyportal
def _read_plotjs(filename):
    """Return the text of a callback script stored in ``static/js/plotjs``.

    The file is read inside a ``with`` block so the handle is closed as soon
    as the text has been loaded.
    """
    path = os.path.join(
        os.path.dirname(__file__), '../static/js/plotjs', filename
    )
    with open(path) as js_file:
        return js_file.read()


def _error_bar_segments(df, x_col, y_col, err_col):
    """Return ``(xs, ys)`` for ``multi_line``: one vertical +/- err bar per row."""
    xs = [(x, x) for x in df[x_col]]
    ys = [(y - err, y + err) for y, err in zip(df[y_col], df[err_col])]
    return xs, ys


def _empty_stack_source():
    """Return an empty ``ColumnDataSource`` with the photometry columns.

    Used as the placeholder source of the ``bin*`` / ``unobsbin*`` renderers.
    NOTE(review): presumably filled client-side by the stacking JS callbacks,
    which are not visible from this function -- confirm.
    """
    columns = (
        'mjd',
        'flux',
        'fluxerr',
        'filter',
        'color',
        'lim_mag',
        'mag',
        'magerr',
        'instrument',
        'stacked',
    )
    return ColumnDataSource(data={name: [] for name in columns})


def _mark_first_last_detections(
    plot, detection_dates, low, high, last_hover_kwargs=None
):
    """Draw hoverable vertical lines at the earliest and latest detection date.

    Nothing is drawn when ``detection_dates`` is empty. ``low`` and ``high``
    are the two ends of the plot's y range; ``last_hover_kwargs`` are extra
    keyword arguments for the "Last detection" ``HoverTool`` only.
    """
    if len(detection_dates) == 0:
        return

    # The line spans five times the y range around its midpoint
    # (presumably so it still fills the frame after panning/zooming).
    midpoint = (high + low) / 2
    line_top = 5 * high - 4 * midpoint
    line_bottom = 5 * low - 4 * midpoint
    y = np.linspace(line_bottom, line_top, num=5000)

    markers = (
        ("First detection", round(detection_dates.min(), 6), "#34b4eb", {}),
        (
            "Last detection",
            round(detection_dates.max(), 6),
            "#8992f5",
            last_hover_kwargs or {},
        ),
    )
    for title, mjd, line_color, hover_kwargs in markers:
        renderer = plot.line(
            x=np.full(5000, mjd),
            y=y,
            line_alpha=0.5,
            line_color=line_color,
            line_width=2,
        )
        plot.add_tools(
            HoverTool(
                tooltips=[(title, f'{mjd}')],
                renderers=[renderer],
                **hover_kwargs,
            )
        )


def photometry_plot(obj_id, user, width=600, height=300, device="browser"):
    """Create the tabbed photometry plot (mag, flux and, optionally, period).

    Parameters
    ----------
    obj_id : str
        ID of Obj to be plotted.
    user
        Requesting user. Only photometry belonging to one of
        ``user.accessible_groups`` is plotted; spectra and annotations are
        fetched through the user-aware accessors.
    width : int, optional
        Width given to the layouts and to the enclosing ``Tabs``.
    height : int, optional
        Height given to the enclosing ``Tabs``.
    device : str, optional
        Device hint. Values containing ``"mobile"`` or ``"tablet"`` get a
        reduced toolbar, a stacked plot/legend layout and no CSV export
        button; ``"mobile_landscape"`` uses a wider aspect ratio and
        ``"mobile_portrait"`` fewer x-axis ticks on the flux plot.

    Returns
    -------
    dict or tuple
        Bokeh JSON embedding for the desired plot, or ``(None, None, None)``
        when the query returns no photometry for this object and user.
    """

    accessible_group_ids = [g.id for g in user.accessible_groups]
    query = (
        DBSession()
        .query(
            Photometry,
            Telescope.nickname.label("telescope"),
            Instrument.name.label("instrument"),
        )
        .join(Instrument, Instrument.id == Photometry.instrument_id)
        .join(Telescope, Telescope.id == Instrument.telescope_id)
        .filter(Photometry.obj_id == obj_id)
        .filter(Photometry.groups.any(Group.id.in_(accessible_group_ids)))
    )
    data = pd.read_sql(query.statement, DBSession().bind)

    if data.empty:
        return None, None, None

    # get spectra to annotate on phot plots
    spectra = (
        Spectrum.query_records_accessible_by(user)
        .filter(Spectrum.obj_id == obj_id)
        .all()
    )

    data['color'] = [get_color(f) for f in data['filter']]

    # legend label: instrument/filter, plus /origin when an origin is set
    data['label'] = [
        f'{instrument}/{filt}' + (f'/{origin}' if origin is not None else '')
        for instrument, filt, origin in zip(
            data['instrument'], data['filter'], data['origin']
        )
    ]
    data['zp'] = PHOT_ZP
    data['magsys'] = 'ab'
    data['alpha'] = 1.0
    data['lim_mag'] = (
        -2.5 * np.log10(data['fluxerr'] * PHOT_DETECTION_THRESHOLD) + data['zp']
    )

    # Passing a dictionary to a bokeh datasource causes the frontend to die,
    # deleting the dictionary column fixes that
    del data['original_user_data']

    # keep track of things that are only upper limits
    data['hasflux'] = ~data['flux'].isna()

    # calculate the magnitudes - a photometry point is considered "significant"
    # or "detected" (and thus can be represented by a magnitude) if its snr
    # is above PHOT_DETECTION_THRESHOLD
    obsind = data['hasflux'] & (
        data['flux'].fillna(0.0) / data['fluxerr'] >= PHOT_DETECTION_THRESHOLD
    )
    data.loc[~obsind, 'mag'] = None
    data.loc[obsind, 'mag'] = -2.5 * np.log10(data[obsind]['flux']) + PHOT_ZP

    # calculate the magnitude errors using standard error propagation formulae
    # https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulae
    data.loc[~obsind, 'magerr'] = None
    coeff = 2.5 / np.log(10)
    magerrs = np.abs(coeff * data[obsind]['fluxerr'] / data[obsind]['flux'])
    data.loc[obsind, 'magerr'] = magerrs
    data['obs'] = obsind
    data['stacked'] = False

    split = data.groupby('label', sort=False)

    # Pad the flux range by 5% of the magnitude of each extreme. Using abs()
    # keeps the padding pointing outwards when the flux is negative (a plain
    # ``min * 0.95`` would move the lower bound *above* a negative minimum).
    finite_flux = data.loc[np.isfinite(data['flux']), 'flux']
    flux_min = np.min(finite_flux)
    flux_max = np.max(finite_flux)
    lower = flux_min - 0.05 * abs(flux_min)
    upper = flux_max + 0.05 * abs(flux_max)

    small_screen = "mobile" in device or "tablet" in device
    active_drag = None if small_screen else "box_zoom"
    tools = (
        'box_zoom,pan,reset'
        if small_screen
        else "box_zoom,wheel_zoom,pan,reset,save"
    )
    aspect_ratio = 2.0 if device == "mobile_landscape" else 1.5

    # ------------------------------------------------------------------
    # flux light curve
    # ------------------------------------------------------------------
    plot = figure(
        aspect_ratio=aspect_ratio,
        sizing_mode='scale_both',
        active_drag=active_drag,
        tools=tools,
        toolbar_location='above',
        toolbar_sticky=True,
        y_range=(lower, upper),
        min_border_right=16,
    )
    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}

    for i, (_, sdf) in enumerate(split):

        # for the flux plot, we only show things that have a flux value
        df = sdf[sdf['hasflux']]

        obs_renderer = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(df),
        )
        model_dict[f'obs{i}'] = obs_renderer
        imhover.renderers.append(obs_renderer)

        bin_renderer = plot.scatter(
            x='mjd',
            y='flux',
            color='color',
            marker='circle',
            fill_color='color',
            source=_empty_stack_source(),
        )
        model_dict[f'bin{i}'] = bin_renderer
        imhover.renderers.append(bin_renderer)

        err_xs, err_ys = _error_bar_segments(df, 'mjd', 'flux', 'fluxerr')
        model_dict[f'obserr{i}'] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=err_xs,
                    ys=err_ys,
                    color=df['color'],
                    alpha=[1.0] * len(df),
                )
            ),
        )

        model_dict[f'binerr{i}'] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

    plot.xaxis.axis_label = 'MJD'
    if device == "mobile_portrait":
        plot.xaxis.ticker.desired_num_ticks = 5
    plot.yaxis.axis_label = 'Flux (μJy)'
    plot.toolbar.logo = None

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline="tablet" in device,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={'toggle': toggle, **model_dict},
            code=_read_plotjs('togglef.js'),
        )
    )

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    # the placeholders in the JS source are substituted textually
    callback = CustomJS(
        args={'slider': slider, 'toggle': toggle, **model_dict},
        code=_read_plotjs('stackf.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )

    slider.js_on_change('value', callback)

    # Mark the first and last detections
    _mark_first_last_detections(
        plot, data[data['hasflux']]['mjd'], lower, upper
    )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, lower, upper)

    plot_layout = column(plot, toggle) if small_screen else row(plot, toggle)
    layout = column(
        slider, plot_layout, sizing_mode='scale_width', width=width
    )

    p1 = Panel(child=layout, title='Flux')

    # ------------------------------------------------------------------
    # mag light curve
    # ------------------------------------------------------------------
    ymax = (
        np.nanmax(
            (
                np.nanmax(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmax(data.loc[~obsind, 'lim_mag'])
                if any(~obsind)
                else np.nan,
            )
        )
        + 0.1
    )
    ymin = (
        np.nanmin(
            (
                np.nanmin(data.loc[obsind, 'mag']) if any(obsind) else np.nan,
                np.nanmin(data.loc[~obsind, 'lim_mag'])
                if any(~obsind)
                else np.nan,
            )
        )
        - 0.1
    )

    xmin = data['mjd'].min() - 2
    xmax = data['mjd'].max() + 2

    # y_range is given as (ymax, ymin) so the magnitude axis is inverted
    plot = figure(
        aspect_ratio=aspect_ratio,
        sizing_mode='scale_both',
        width=width,
        active_drag=active_drag,
        tools=tools,
        y_range=(ymax, ymin),
        x_range=(xmin, xmax),
        toolbar_location='above',
        toolbar_sticky=True,
        x_axis_location='above',
    )

    # Mark the first and last detections again
    _mark_first_last_detections(
        plot,
        data[obsind]['mjd'],
        ymin,
        ymax,
        last_hover_kwargs={'point_policy': 'follow_mouse'},
    )

    # Mark when spectra were taken
    annotate_spec(plot, spectra, ymax, ymin)

    imhover = HoverTool(tooltips=tooltip_format)
    imhover.renderers = []
    plot.add_tools(imhover)

    model_dict = {}

    for i, (_, df) in enumerate(split):

        detected = df[df['obs']]

        obs_renderer = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            alpha='alpha',
            source=ColumnDataSource(detected),
        )
        model_dict[f'obs{i}'] = obs_renderer
        imhover.renderers.append(obs_renderer)

        # non-detections are drawn as slightly transparent limit markers
        unobs_source = df[~df['obs']].copy()
        unobs_source.loc[:, 'alpha'] = 0.8

        unobs_renderer = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha='alpha',
            source=ColumnDataSource(unobs_source),
        )
        model_dict[f'unobs{i}'] = unobs_renderer
        imhover.renderers.append(unobs_renderer)

        bin_renderer = plot.scatter(
            x='mjd',
            y='mag',
            color='color',
            marker='circle',
            fill_color='color',
            source=_empty_stack_source(),
        )
        model_dict[f'bin{i}'] = bin_renderer
        imhover.renderers.append(bin_renderer)

        err_xs, err_ys = _error_bar_segments(detected, 'mjd', 'mag', 'magerr')
        model_dict[f'obserr{i}'] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            alpha='alpha',
            source=ColumnDataSource(
                data=dict(
                    xs=err_xs,
                    ys=err_ys,
                    color=detected['color'],
                    alpha=[1.0] * len(detected),
                )
            ),
        )

        model_dict[f'binerr{i}'] = plot.multi_line(
            xs='xs',
            ys='ys',
            color='color',
            source=ColumnDataSource(data=dict(xs=[], ys=[], color=[])),
        )

        unobsbin_renderer = plot.scatter(
            x='mjd',
            y='lim_mag',
            color='color',
            marker='inverted_triangle',
            fill_color='white',
            line_color='color',
            alpha=0.8,
            source=_empty_stack_source(),
        )
        model_dict[f'unobsbin{i}'] = unobsbin_renderer
        imhover.renderers.append(unobsbin_renderer)

        # raw data sources handed to the JS callbacks alongside the renderers
        model_dict[f'all{i}'] = ColumnDataSource(df)

        model_dict[f'bold{i}'] = ColumnDataSource(
            df[
                [
                    'mjd',
                    'flux',
                    'fluxerr',
                    'mag',
                    'magerr',
                    'filter',
                    'zp',
                    'magsys',
                    'lim_mag',
                    'stacked',
                ]
            ]
        )

    plot.xaxis.axis_label = 'MJD'
    plot.yaxis.axis_label = 'AB mag'
    plot.toolbar.logo = None

    obj = DBSession().query(Obj).get(obj_id)
    if obj.dm is not None:
        # secondary y axis shifted by the distance modulus
        plot.extra_y_ranges = {
            "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
        }
        plot.add_layout(
            LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"),
            'right',
        )

    now = Time.now().mjd
    plot.extra_x_ranges = {
        "Days Ago": Range1d(start=now - xmin, end=now - xmax)
    }
    plot.add_layout(
        LinearAxis(x_range_name="Days Ago", axis_label="Days Ago"), 'below'
    )

    colors_labels = data[['color', 'label']].drop_duplicates()

    toggle = CheckboxWithLegendGroup(
        labels=colors_labels.label.tolist(),
        active=list(range(len(colors_labels))),
        colors=colors_labels.color.tolist(),
        width=width // 5,
        inline="tablet" in device,
    )

    # TODO replace `eval` with Namespaces
    # https://github.com/bokeh/bokeh/pull/6340
    toggle.js_on_click(
        CustomJS(
            args={'toggle': toggle, **model_dict},
            code=_read_plotjs('togglem.js'),
        )
    )

    slider = Slider(
        start=0.0,
        end=15.0,
        value=0.0,
        step=1.0,
        title='Binsize (days)',
        max_width=350,
        margin=(4, 10, 0, 10),
    )

    button = Button(label="Export Bold Light Curve to CSV")
    button.js_on_click(
        CustomJS(
            args={'slider': slider, 'toggle': toggle, **model_dict},
            code=_read_plotjs("download.js")
            .replace('objname', obj_id)
            .replace('default_zp', str(PHOT_ZP)),
        )
    )

    # Don't need to expose CSV download on mobile
    top_layout = slider if small_screen else row(slider, button)

    callback = CustomJS(
        args={'slider': slider, 'toggle': toggle, **model_dict},
        code=_read_plotjs('stackm.js')
        .replace('default_zp', str(PHOT_ZP))
        .replace('detect_thresh', str(PHOT_DETECTION_THRESHOLD)),
    )
    slider.js_on_change('value', callback)
    plot_layout = column(plot, toggle) if small_screen else row(plot, toggle)
    layout = column(
        top_layout, plot_layout, sizing_mode='scale_width', width=width
    )

    p2 = Panel(child=layout, title='Mag')

    # ------------------------------------------------------------------
    # period (phase-folded) plot
    # ------------------------------------------------------------------

    # get periods from annotations
    annotation_list = obj.get_annotations_readable_by(user)
    period_labels = []
    period_list = []
    for an in annotation_list:
        if 'period' in an.data:
            period_list.append(an.data['period'])
            period_labels.append(an.origin + ": %.9f" % an.data['period'])

    # the first annotated period is the initial folding period
    period = period_list[0] if period_list else None

    # don't generate if no period annotated
    if period is not None:

        # bokeh figure for period plotting
        period_plot = figure(
            aspect_ratio=1.5,
            sizing_mode='scale_both',
            active_drag='box_zoom',
            tools='box_zoom,wheel_zoom,pan,reset,save',
            y_range=(ymax, ymin),
            x_range=(-0.1, 1.1),  # initially one phase
            toolbar_location='above',
            toolbar_sticky=False,
            x_axis_location='below',
        )

        # axis labels
        period_plot.xaxis.axis_label = 'phase'
        period_plot.yaxis.axis_label = 'mag'
        period_plot.toolbar.logo = None

        # do we have a distance modulus (dm)?
        if obj.dm is not None:
            period_plot.extra_y_ranges = {
                "Absolute Mag": Range1d(start=ymax - obj.dm, end=ymin - obj.dm)
            }
            period_plot.add_layout(
                LinearAxis(y_range_name="Absolute Mag", axis_label="m - DM"),
                'right',
            )

        # initiate hover tool
        period_imhover = HoverTool(tooltips=tooltip_format)
        period_imhover.renderers = []
        period_plot.add_tools(period_imhover)

        # initiate period radio buttons
        period_selection = RadioGroup(labels=period_labels, active=0)

        phase_selection = RadioGroup(
            labels=["One phase", "Two phases"], active=0
        )

        # store all the plot data
        period_model_dict = {}

        # iterate over each label group
        for i, (_, df) in enumerate(split):

            # fold x-axis on period in days; assign() returns a new frame so
            # the groupby slice itself is never written to
            phase = (df['mjd'] % period) / period
            folded = df.assign(mjd_folda=phase, mjd_foldb=phase + 1.0)
            detected = folded[folded['obs']]  # only visible data

            # phase plotting: 'a' is phase 0-1, 'b' the repeated phase 1-2
            for ph in ['a', 'b']:
                fold_col = f'mjd_fold{ph}'
                initially_visible = ph == 'a'

                fold_renderer = period_plot.scatter(
                    x=fold_col,
                    y='mag',
                    color='color',
                    marker='circle',
                    fill_color='color',
                    alpha='alpha',
                    visible=initially_visible,
                    source=ColumnDataSource(detected),
                )
                period_model_dict[f'fold{ph}{i}'] = fold_renderer
                # add to hover tool
                period_imhover.renderers.append(fold_renderer)

                # errorbars for phases
                err_xs, err_ys = _error_bar_segments(
                    detected, fold_col, 'mag', 'magerr'
                )
                period_model_dict[f'fold{ph}err{i}'] = period_plot.multi_line(
                    xs='xs',
                    ys='ys',
                    color='color',
                    alpha='alpha',
                    visible=initially_visible,
                    source=ColumnDataSource(
                        data=dict(
                            xs=err_xs,
                            ys=err_ys,
                            color=detected['color'],
                            alpha=[1.0] * len(detected),
                        )
                    ),
                )

        # toggle for folded photometry
        period_toggle = CheckboxWithLegendGroup(
            labels=colors_labels.label.tolist(),
            active=list(range(len(colors_labels))),
            colors=colors_labels.color.tolist(),
            width=width // 5,
        )
        # use javascript to perform toggling on click
        # TODO replace `eval` with Namespaces
        # https://github.com/bokeh/bokeh/pull/6340
        period_toggle.js_on_click(
            CustomJS(
                args={
                    'toggle': period_toggle,
                    'numphases': phase_selection,
                    'p': period_plot,
                    **period_model_dict,
                },
                code=_read_plotjs('togglep.js'),
            )
        )

        # the same folding script serves the text box and the phase selector
        foldphase_js = _read_plotjs('foldphase.js')
        foldphase_args = {
            'toggle': period_toggle,
            'numphases': phase_selection,
            'p': period_plot,
            **period_model_dict,
        }

        # set up period adjustment text box
        period_title = Div(text="Period (days): ")
        period_textinput = TextInput(value=str(period))
        foldphase_args = {'textinput': period_textinput, **foldphase_args}
        period_textinput.js_on_change(
            'value',
            CustomJS(args=dict(foldphase_args), code=foldphase_js),
        )
        # a way to modify the period
        period_double_button = Button(label="*2")
        period_double_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                const period = parseFloat(textinput.value);
                textinput.value = parseFloat(2.*period).toFixed(9);
                """,
            ))
        period_halve_button = Button(label="/2")
        period_halve_button.js_on_click(
            CustomJS(
                args={'textinput': period_textinput},
                code="""
                        const period = parseFloat(textinput.value);
                        textinput.value = parseFloat(period/2.).toFixed(9);
                        """,
            ))
        # a way to select the period
        period_selection.js_on_click(
            CustomJS(
                args={
                    'textinput': period_textinput,
                    'periods': period_list
                },
                code="""
                textinput.value = parseFloat(periods[this.active]).toFixed(9);
                """,
            ))
        phase_selection.js_on_click(
            CustomJS(args=dict(foldphase_args), code=foldphase_js)
        )

        # layout

        period_column = column(
            period_toggle,
            period_title,
            period_textinput,
            period_selection,
            row(period_double_button, period_halve_button, width=180),
            phase_selection,
            width=180,
        )

        period_layout = column(
            row(period_plot, period_column),
            sizing_mode='scale_width',
            width=width,
        )

        # Period panel
        p3 = Panel(child=period_layout, title='Period')
        # tabs for mag, flux, period
        panels = [p2, p1, p3]
    else:
        # tabs for mag, flux
        panels = [p2, p1]

    tabs = Tabs(tabs=panels, width=width, height=height, sizing_mode='fixed')
    return bokeh_embed.json_item(tabs)