Exemple #1
0
def gpu(doc):
    """Attach a live per-GPU utilization bar chart to a Bokeh document.

    One horizontal bar is drawn per device index in ``range(ngpus)``; bar
    colour comes from a 4-colour RdYlBu palette spanning 0-100. NVML is
    re-polled every 200 ms through a periodic callback registered on ``doc``.

    Parameters
    ----------
    doc : bokeh document
        Receives the figure (as a root) and the periodic callback.
    """
    fig = figure(
        title="GPU Utilization",
        sizing_mode="stretch_both",
        x_range=[0, 100],
    )

    def read_utilization():
        # NVML ``.gpu`` utilization value for each device index, in order.
        readings = []
        for idx in range(ngpus):
            rates = pynvml.nvmlDeviceGetUtilizationRates(gpu_handles[idx])
            readings.append(rates.gpu)
        return readings

    initial = read_utilization()
    # "right" holds the bar row positions (used as the y coordinate below).
    source = ColumnDataSource({
        "right": [row_index for row_index in range(len(initial))],
        "gpu": initial,
    })
    palette = all_palettes["RdYlBu"][4]
    mapper = LinearColorMapper(palette=palette, low=0, high=100)

    fig.hbar(
        source=source,
        y="right",
        right="gpu",
        height=0.8,
        color={"field": "gpu", "transform": mapper},
    )
    fig.toolbar_location = None

    doc.title = "GPU Utilization [%]"
    doc.add_root(fig)

    def refresh():
        # Only bar lengths change; the row positions are static.
        source.data.update({"gpu": read_utilization()})

    doc.add_periodic_callback(refresh, 200)
def gpu_mem(doc):
    """Attach a live per-GPU memory-usage bar chart to a Bokeh document.

    One horizontal bar per handle in ``gpu_handles`` shows used device memory
    in bytes; the figure title is updated with the total used across all
    devices. Values refresh every 200 ms.

    The x-axis and colour-scale upper bound is the largest total memory of
    any device, so on machines with GPUs of different sizes no bar can run
    past the end of the axis.

    Parameters
    ----------
    doc : bokeh document
        Receives the figure (as a root) and the periodic callback.
    """

    def get_mem():
        # Used memory (bytes) per device, in handle order.
        return [pynvml.nvmlDeviceGetMemoryInfo(handle).used
                for handle in gpu_handles]

    def get_total():
        # Largest per-device total memory: sizing by device 0 alone
        # under-sizes the axis when devices differ.
        return max(pynvml.nvmlDeviceGetMemoryInfo(handle).total
                   for handle in gpu_handles)

    # Query once and reuse for both the axis range and the colour mapper.
    total = get_total()

    fig = figure(title="GPU Memory",
                 sizing_mode="stretch_both", x_range=[0, total])

    gpu = get_mem()

    y = list(range(len(gpu)))
    source = ColumnDataSource({"right": y, "gpu": gpu})
    mapper = LinearColorMapper(
        palette=all_palettes['RdYlBu'][8], low=0, high=total)

    fig.hbar(
        source=source,
        y="right",
        right='gpu',
        height=0.8,
        color={"field": "gpu", "transform": mapper},
    )
    # Byte-formatted, slightly rotated tick labels.
    fig.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b")
    fig.xaxis.major_label_orientation = -math.pi / 12

    fig.toolbar_location = None

    doc.title = "GPU Memory"
    doc.add_root(fig)

    def cb():
        mem = get_mem()
        source.data.update({"gpu": mem})
        fig.title.text = "GPU Memory: {}".format(format_bytes(sum(mem)))

    doc.add_periodic_callback(cb, 200)
# Two multi-line sources sharing x coordinates and sentiment values; in the
# second, every ys entry is the pair [0, -60].
multilinesource1 = ColumnDataSource(
    dict(xs=xs, ys=ys, sentiment=average_sentiment_array))
multilinesource2 = ColumnDataSource(
    dict(xs=xs,
         ys=[[0, -60] for _ in range(len(article_2021))],
         sentiment=average_sentiment_array))

# Sample every entry of matplotlib's RdYlGn colormap and convert each RGBA
# row to a hex string so Bokeh can use it as a palette over [-1, 1].
colormap = cm.get_cmap("RdYlGn")
bokehpalette = list(map(mpl.colors.rgb2hex, colormap(np.arange(colormap.N))))
mapper = LinearColorMapper(palette=bokehpalette, low=-1.0, high=1.0)

tools = 'tap, wheel_zoom, pan, reset, xwheel_pan'
p = figure(
    title="News articles clustering based on bag of words overlap",
    tools=tools,
    plot_width=1000,
    plot_height=800,
    x_range=(0, 800),
    y_range=(-400, 400),
)

p.line(x='x', y=0, line_width=2, source=linesource)
p.circle(x='x',
         y=0,
         radius=5,
         color={
             'field': 'sentiment',
Exemple #4
0
def test_LinearColorMapper():
    """Check a default LinearColorMapper via check_properties_existence.

    The expected names cover the palette, the mapped range, and the colours
    used for out-of-range and NaN values.
    """
    expected = [
        "palette",
        "low",
        "high",
        "low_color",
        "high_color",
        "nan_color",
    ]
    check_properties_existence(LinearColorMapper(), expected)
Exemple #5
0
    def create(cls):
        """Construct an instance with its data source, sliders, image plot and layout.

        Sets ``source``, ``value1``, ``value2``, ``plot`` and ``controls`` on
        the new instance, calls ``update_data()`` once the plot exists, and
        appends the plot and then the controls to ``children``.

        Returns
        -------
        The newly built instance.
        """
        instance = cls()

        # Image column starts empty; presumably populated by update_data().
        instance.source = ColumnDataSource(data=dict(z=[]))

        # Both sliders share the same range and step.
        # NOTE(review): earlier comments called these the solver step size and
        # the ODE initial value; the generic names suggest otherwise — confirm.
        slider_span = dict(start=-1, end=+1, step=.1)
        instance.value2 = Slider(title="value2", name='value2', value=1,
                                 **slider_span)
        instance.value1 = Slider(title="value1", name='value1', value=0,
                                 **slider_span)

        # Square 400x400 plot over [-1, 1] x [-1, 1].
        plot = figure(
            title="somestuff",
            title_text_font_size="12pt",
            plot_width=400,
            plot_height=400,
            tools="crosshair,pan,reset,resize,wheel_zoom,box_zoom",
            x_range=[-1, 1],
            y_range=[-1, 1],
        )

        # The 'z' image fills the whole 2x2 view, coloured over [-2, 2].
        jet_mapper = LinearColorMapper(palette=svg_palette_jet, low=-2, high=2)
        plot.image(image='z', x=-1, y=-1, dw=2, dh=2,
                   color_mapper=jet_mapper, source=instance.source)

        instance.plot = plot
        instance.update_data()

        instance.controls = VBoxForm(children=[instance.value1, instance.value2])

        # Layout order: plot first, then the slider column.
        instance.children.append(instance.plot)
        instance.children.append(instance.controls)

        return instance
Exemple #6
0
    def plot_moll(self):
        """Build the interactive Mollweide map figure.

        Draws ``self.source_moll``'s ``"moll"`` image over x in [-2, 2],
        y in [-1, 1] with a Plasma256 mapper bounded by ``self.vmin_m`` and
        ``self.vmax_m`` (NaN pixels white). It then adds a faint lat/lon grid
        and the border from ``self.add_border``, and wires client-side
        callbacks:

        * MouseMove: inside the map, and when ``moll[0]`` at the cursor pixel
          is not NaN, writes ``sum_k moll[k][j][i] * spec[k]`` into
          ``self.source_spec.data["spec"]``.
        * MouseWheel: shifts the wavelength index in ``self.source_index`` by
          the floored wheel delta, clamped to ``[0, nws - 1]``. It moves
          ``self.spec_vline`` to ``wavs[l]`` and replaces the displayed map
          with ``sum_k spec[k][l] * moll[k]``.
        * MouseEnter / MouseLeave: set the JS global ``DISABLE_WHEEL`` to
          true / false (presumably read elsewhere to stop page scrolling
          while over the map — confirm).

        Returns
        -------
        The configured Bokeh figure.
        """
        plot_moll = figure(
            aspect_ratio=2,
            toolbar_location=None,
            x_range=(-2, 2),
            y_range=(-1, 1),
            id="plot_moll",
            name="plot_moll",
        )
        # Bare map: no axes, grid lines or outline.
        plot_moll.axis.visible = False
        plot_moll.grid.visible = False
        plot_moll.outline_line_color = None

        plot_moll.image(
            image="moll",
            x=-2,
            y=-1,
            dw=4,
            dh=2,
            color_mapper=LinearColorMapper(
                palette="Plasma256",
                nan_color="white",
                low=self.vmin_m,
                high=self.vmax_m,
            ),
            source=self.source_moll,
        )
        # No default active gestures; interaction comes from the JS callbacks.
        for gesture in ("active_drag", "active_scroll", "active_tap"):
            setattr(plot_moll.toolbar, gesture, None)

        # Faint lat/lon grid; helper coordinates are divided by sqrt(2).
        root2 = np.sqrt(2)
        for grid_lines in (get_moll_latitude_lines(), get_moll_longitude_lines()):
            for x, y in grid_lines:
                plot_moll.line(
                    x / root2,
                    y / root2,
                    line_width=1,
                    color="black",
                    alpha=0.25,
                )
        self.add_border(plot_moll, "moll")

        # Arguments used by both the hover and the wheel callbacks.
        shared_args = {
            "moll": self.moll,
            "spec": self.spec,
            "npix_m": self.npix_m,
            "nc": self.nc,
            "nws": self.nws,
        }

        # Hover: show the spectrum of the pixel under the cursor.
        plot_moll.js_on_event(
            MouseMove,
            CustomJS(
                args=dict(shared_args, source_spec=self.source_spec),
                code="""
                var x = cb_obj["x"];
                var y = cb_obj["y"];

                if ((x > - 2) && (x < 2) && (y > -1) && (y < 1)) {

                    // Image index below cursor
                    var i = Math.floor(0.25 * (x + 2) * npix_m);
                    var j = Math.floor(0.5 * (y + 1) * npix_m);

                    // Compute weighted spectrum
                    if (!isNaN(moll[0][j][i])) {
                        var local_spec = new Array(nws).fill(0);
                        for (var k = 0; k < nc; k++) {
                            var weight = moll[k][j][i];
                            for (var l = 0; l < nws; l++) {
                                local_spec[l] += weight * spec[k][l]
                            }
                        }

                        // Update the plot
                        source_spec.data["spec"] = local_spec;
                        source_spec.change.emit();
                    }
                }
                """,
            ),
        )

        # Wheel: step through wavelengths and redraw the map.
        plot_moll.js_on_event(
            MouseWheel,
            CustomJS(
                args=dict(
                    shared_args,
                    source_moll=self.source_moll,
                    source_index=self.source_index,
                    spec_vline=self.spec_vline,
                    wavs=self.wavs,
                ),
                code="""
                // Update the current wavelength index
                var delta = Math.floor(cb_obj["delta"]);
                var l = source_index.data["l"][0];
                l += delta;
                if (l < 0) l = 0;
                if (l > nws - 1) l = nws - 1;
                source_index.data["l"][0] = l;
                source_index.change.emit();
                spec_vline.location = wavs[l];

                // Update the map
                var local_moll = new Array(npix_m).fill(0).map(() => new Array(npix_m).fill(0));
                for (var k = 0; k < nc; k++) {
                    var weight = spec[k][l];
                    for (var i = 0; i < npix_m; i++) {
                        for (var j = 0; j < npix_m; j++) {
                            local_moll[j][i] += weight * moll[k][j][i];
                        }
                    }
                }
                source_moll.data["moll"][0] = local_moll;
                source_moll.change.emit();
                """,
            ),
        )

        # Toggle a page-level JS flag while the pointer is over the map.
        plot_moll.js_on_event(MouseEnter, CustomJS(code="""
            DISABLE_WHEEL= true;
            """))
        plot_moll.js_on_event(MouseLeave, CustomJS(code="""
            DISABLE_WHEEL = false;
            """))

        return plot_moll
Exemple #7
0
def pci(doc):
    """Attach live PCIe TX/RX throughput bar charts to a Bokeh document.

    Two figures stacked in a column show, for each device index in
    ``range(ngpus)``, the PCIe TX and RX throughput reported by NVML,
    converted from KB/s to GB/s. Both y-axes and the shared colour scale span
    0 to the theoretical per-direction maximum for device 0's maximum link
    generation and width. Bars are re-polled every 200 ms.

    Parameters
    ----------
    doc : bokeh document
        Receives the layout (as a root) and the periodic callback.
    """

    # Use device-0 to get "upper bound"
    pci_gen = pynvml.nvmlDeviceGetMaxPcieLinkGeneration(gpu_handles[0])
    pci_width = pynvml.nvmlDeviceGetMaxPcieLinkWidth(gpu_handles[0])
    pci_bw = {
        # Keys = PCIe-Generation, Values = Max PCIe Lane BW (per direction)
        # [Note: Using specs at https://en.wikipedia.org/wiki/PCI_Express]
        1: (250.0 / 1024.0),
        2: (500.0 / 1024.0),
        3: (985.0 / 1024.0),
        4: (1969.0 / 1024.0),
        5: (3938.0 / 1024.0),
        6: (7877.0 / 1024.0),
    }
    # A generation newer than this table would otherwise raise KeyError and
    # break the whole page; fall back to the fastest known lane rate. The
    # y-range may then be too small and clip bars, but the page renders.
    lane_bw = pci_bw.get(pci_gen, pci_bw[max(pci_bw)])
    # Max PCIe Throughput = (BW-per-lane * Width)
    max_rxtx_tp = pci_width * lane_bw

    def read_throughput(counter):
        # NVML reports PCIe throughput in KB/s; divide by 1024**2 for GB/s.
        return [
            pynvml.nvmlDeviceGetPcieThroughput(gpu_handles[i], counter)
            / (1024.0 * 1024.0)
            for i in range(ngpus)
        ]

    pci_tx = read_throughput(pynvml.NVML_PCIE_UTIL_TX_BYTES)
    pci_rx = read_throughput(pynvml.NVML_PCIE_UTIL_RX_BYTES)

    # Bar i spans [i, i + 0.8] on the x-axis.
    left = list(range(ngpus))
    right = [x + 0.8 for x in left]
    source = ColumnDataSource({
        "left": left,
        "right": right,
        "pci-tx": pci_tx,
        "pci-rx": pci_rx
    })
    mapper = LinearColorMapper(palette=all_palettes["RdYlBu"][4],
                               low=0,
                               high=max_rxtx_tp)

    def make_figure(title, field):
        # One vertical-bar figure driven by the given source column.
        fig = figure(title=title,
                     sizing_mode="stretch_both",
                     y_range=[0, max_rxtx_tp])
        fig.quad(
            source=source,
            left="left",
            right="right",
            bottom=0,
            top=field,
            color={
                "field": field,
                "transform": mapper
            },
        )
        fig.toolbar_location = None
        return fig

    tx_fig = make_figure("TX Bytes [GB/s]", "pci-tx")
    rx_fig = make_figure("RX Bytes [GB/s]", "pci-rx")

    doc.title = "PCI Throughput"
    doc.add_root(column(tx_fig, rx_fig, sizing_mode="stretch_both"))

    def cb():
        source.data.update({
            "pci-tx": read_throughput(pynvml.NVML_PCIE_UTIL_TX_BYTES),
            "pci-rx": read_throughput(pynvml.NVML_PCIE_UTIL_RX_BYTES),
        })

    doc.add_periodic_callback(cb, 200)
Exemple #8
0
def create_summary_heatmap(df, std_dev_sem_columns):
    """Build a Tool x Metric heatmap summarising per-tool metric rates.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per tool with a ``'Tool'`` column; every other column is a
        metric holding a numeric rate (presumably in [0, 1] given the colour
        scale and the "good"/"bad" tick labels — confirm).
    std_dev_sem_columns : iterable of str
        Metric columns shown without colour weighting. They are moved to the
        right-hand end of the x-axis and painted with the "Unweighted"
        colour. Each must be a column of ``df`` (``list.index`` raises
        ValueError otherwise).

    Returns
    -------
    list
        A single-element list holding the Bokeh figure.
    """
    # Reverse rows so the first tool appears at the top of the y-axis.
    df = df.set_index('Tool').iloc[::-1]
    tools = list(df.index)
    metrics = list(reversed(list(df.columns)))

    # unweighted columns should be at the right side
    for metric in std_dev_sem_columns:
        metrics.append(metrics.pop(metrics.index(metric)))

    df = df[metrics]

    # Sentinel placed just above the top of the colour scale for unweighted cells.
    UNWEIGHTED_NUMBER = 1.10001
    WEIGHTING_COLUMN = 'rate_extended'
    DEFAULT_TOOL_HEIGHT = 10
    COLORBAR_HEIGHT = 150
    ALPHA_COLOR = 0.85

    # Long format: one row per (Tool, Metrics) cell with its value in 'rate'.
    df.columns.name = 'Metrics'
    df = pd.DataFrame(df.stack(), columns=['rate']).reset_index()
    # Copy the numeric rate for colour mapping *before* 'rate' becomes display
    # text, so the colour mapper receives numbers rather than formatted
    # strings. With the ',' thousands separator, large values would not even
    # parse as numbers.
    df[WEIGHTING_COLUMN] = df['rate']
    df['rate'] = df['rate'].map('{:,.5f}'.format)

    for metric in std_dev_sem_columns:
        df.loc[df.Metrics == metric, WEIGHTING_COLUMN] = UNWEIGHTED_NUMBER

    mapper = LinearColorMapper(palette=HEATMAP_COLORS, low=0, high=UNWEIGHTED_NUMBER)
    source = ColumnDataSource(df)

    p = figure(x_range=metrics, y_range=tools,
               x_axis_location="above", plot_height=len(tools) * DEFAULT_TOOL_HEIGHT + COLORBAR_HEIGHT,
               tools="hover,save,box_zoom,reset,wheel_zoom", toolbar_location='below')

    p = _set_default_figure_properties(p, "Metrics", "Tools")
    p.xaxis.major_label_orientation = pi / 2.5

    # Coloured cell background.
    p.rect(x="Metrics", y="Tool",
           width=1,
           height=1,
           source=source,
           alpha=ALPHA_COLOR,
           fill_color={'field': WEIGHTING_COLUMN, 'transform': mapper},
           line_color="black")

    # Formatted value printed in the centre of each cell.
    glyph = Text(x="Metrics", y="Tool", text_align="center",
                 text_font_size="10pt",
                 text_baseline="middle", text="rate", text_color="black")
    p.add_glyph(source, glyph)

    # Colour-bar labels: 1 -> "good", 0 -> "bad", 1.1 -> "Unweighted".
    tick_formatter = FuncTickFormatter(code="""
    if(tick==1){
        return tick + " (good)"
    } else if(tick==0){
        return tick + " (bad)"
    } else if(tick==1.1){
        return "Unweighted"
    } else {
        return tick.toLocaleString(
  undefined, // use a string like 'en-US' to override browser locale
  { minimumFractionDigits: 2 }
);
    }
    """)

    color_bar = ColorBar(color_mapper=mapper,
                         major_label_text_font_size="12pt",
                         ticker=BasicTicker(desired_num_ticks=len(HEATMAP_COLORS)),
                         scale_alpha=ALPHA_COLOR,
                         major_label_text_align="right",
                         major_label_text_baseline="middle",
                         bar_line_color="black",
                         formatter=tick_formatter,
                         label_standoff=13,
                         orientation="horizontal",
                         location=(-250, 0))

    p.add_layout(color_bar, 'above')
    p.select_one(HoverTool).tooltips = [
        ('Metric', '@Metrics'),
        ('Tool', '@Tool'),
        ('Value', '@rate'),
    ]
    return [p]
Exemple #9
0
def plot_pca_space(
    x,
    y,
    beta,
    plot_type,
    target_name,
    classification_strings=None,
    plot_both=True,
    output_html=None,
    width=500,
    height=500,
    sizing_mode="stretch_both",
):
    """Plot regression predictions in a 2-component PCA space.

    This function has two plot modes, specified by the `plot_both` flag. If
    `plot_both` == True, this plots side-by-side scatter plots of the target
    variable in 2-D PCA space. The right plot is the post-SGL weighted feature
    matrix and the left plot is the pre-SGL original feature matrix.

    Otherwise this plots only the post-SGL weighted feature space and also
    plots a contour of the regression prediction. The contour shares the
    scatter points' colour mapper, so background and points are directly
    comparable against the colour bar.

    Parameters
    ----------
    x : numpy.ndarray
        Feature matrix

    y : pandas.Series
        Target values (class labels for classification, continuous values
        for regression); its index is shown as the subject id in the hover
        tool

    beta : numpy.ndarray
        Regression coefficients

    plot_type : 'regression' or 'classification'
        Type of ML problem

    target_name : string
        The name of the target variable (used in hover tool)

    classification_strings : dict or None
        Dictionary mapping the categorical numerical target values onto their
        names. Values without an entry are shown as-is; None shows all
        values as-is. If `plot_type` == "regression", this parameter is not
        used.

    plot_both : boolean, default=True
        If True, plot the PCA in both the original feature space and the
        feature space projected onto the coefficient vector

    output_html : string or None, default=None
        Filename for bokeh html output. If None, figure will not be saved
        and is shown instead

    width : int, default=500
        Width of each beta plot (in pixels)

    height : int, default=500
        Height of each beta plot (in pixels)

    sizing_mode : string
        One of ("fixed", "stretch_both", "scale_width", "scale_height",
        "scale_both"). Specifies how will the items in the layout resize to
        fill the available space. Default is "stretch_both". For more
        information on the different modes see
        https://bokeh.pydata.org/en/latest/docs/reference/models/layouts.html#bokeh.models.layouts.LayoutDOM

    Raises
    ------
    ValueError
        If `plot_type` is not "regression" or "classification".
    """
    if plot_type not in ["regression", "classification"]:
        raise ValueError(
            '`plot_type` must be either "classification" or ' '"regression"'
        )

    # Project every sample onto the coefficient direction:
    # (x . beta) beta / |beta|^2.
    x_projection = np.outer(x.dot(beta), beta) / (np.linalg.norm(beta) ** 2.0)

    pca_orig = PCA(n_components=2)
    pca_sgl = PCA(n_components=2)

    x2_sgl = pca_sgl.fit_transform(x_projection)
    x2_orig = pca_orig.fit_transform(x)

    if plot_type == "classification":
        cmap = plt.get_cmap("RdBu")
        colors = [to_hex(c) for c in cmap(np.linspace(1, 0, 256))]
    else:
        colors = Cividis256

    color_mapper = LinearColorMapper(palette=colors, low=np.min(y), high=np.max(y))

    color_bar = ColorBar(
        color_mapper=color_mapper,
        ticker=FixedTicker(ticks=np.arange(np.min(y), np.max(y), 5)),
        label_standoff=8,
        border_line_color=None,
        location=(0, 0),
    )

    if plot_type == "classification":
        # Cast to object first so string labels can be stored without pandas
        # upcasting (and warning about) a numeric Series in place.
        target = y.astype(object)
        for k, v in (classification_strings or {}).items():
            target[y == k] = v
    else:
        target = y.copy()

    pc_info = {
        "pc0_sgl": x2_sgl[:, 0],
        "pc1_sgl": x2_sgl[:, 1],
        "target": y.values,
        "target_string": target.values,
        "subject_id": target.index,
    }

    ps = [None]

    if plot_both:
        pc_info["pc0_orig"] = x2_orig[:, 0]
        pc_info["pc1_orig"] = x2_orig[:, 1]
        ps = [None] * 2

    tooltips = [("Subject", "@subject_id"), (target_name, "@target_string")]

    source = ColumnDataSource(data=pc_info)
    code = "source.set('selected', cb_data.index);"
    callback = CustomJS(args={"source": source}, code=code)

    if not plot_both:
        ps[0] = figure(
            plot_width=int(width * 1.1), plot_height=height, toolbar_location="right"
        )

        # Prediction grid over the post-SGL PCA extent, padded by 10%.
        npoints = 200
        dx = np.max(x2_sgl[:, 0]) - np.min(x2_sgl[:, 0])
        xmid = 0.5 * (np.max(x2_sgl[:, 0]) + np.min(x2_sgl[:, 0]))
        xmin = xmid - (dx * 1.1 / 2.0)
        xmax = xmid + (dx * 1.1 / 2.0)

        dy = np.max(x2_sgl[:, 1]) - np.min(x2_sgl[:, 1])
        ymid = 0.5 * (np.max(x2_sgl[:, 1]) + np.min(x2_sgl[:, 1]))
        ymin = ymid - (dy * 1.1 / 2.0)
        ymax = ymid + (dy * 1.1 / 2.0)

        x_subspace = np.linspace(xmin, xmax, npoints)
        y_subspace = np.linspace(ymin, ymax, npoints)
        # x varies slowest, y fastest: row index = ix * npoints + iy.
        subspace_pairs = np.array(list(itertools.product(x_subspace, y_subspace)))
        bigspace_pairs = (
            pca_sgl.inverse_transform(subspace_pairs) * np.linalg.norm(beta) ** 2.0
        )
        predict_pairs = bigspace_pairs.dot(
            np.divide(
                np.ones_like(beta), beta, out=np.zeros_like(beta), where=beta != 0
            )
        )
        # Column-major reshape puts iy on rows and ix on columns, the layout
        # Bokeh's image glyph expects (rows along y).
        p_grid = predict_pairs.reshape((npoints, npoints), order="F")

        # Use the shared mapper so the contour agrees with the colour bar
        # and with the scatter points' colours.
        ps[0].image(
            image=[p_grid],
            x=xmin,
            y=ymin,
            dw=dx * 1.1,
            dh=dy * 1.1,
            color_mapper=color_mapper,
        )

        ps[0].add_layout(color_bar, "right")
        ps[0].x_range = Range1d(xmin, xmax)
        ps[0].y_range = Range1d(ymin, ymax)
    else:
        ps[0] = figure(plot_width=width, plot_height=height, toolbar_location="right")

        if plot_type == "regression":
            ps[0].add_layout(color_bar, "right")

    if plot_type == "regression":
        ps[0].title.text = "Regression in Post-SGL PCA space"
        s0 = ps[0].scatter(
            "pc0_sgl",
            "pc1_sgl",
            source=source,
            size=20,
            fill_color={"field": "target", "transform": color_mapper},
            line_color="white",
            line_width=2.5,
        )
    else:
        ps[0].title.text = "Classification in Post-SGL PCA space"
        s0 = ps[0].scatter(
            "pc0_sgl",
            "pc1_sgl",
            source=source,
            size=20,
            fill_color={"field": "target", "transform": color_mapper},
            line_color="white",
            line_width=2.5,
            legend="target_string",
        )

    hover0 = HoverTool(tooltips=tooltips, callback=callback, renderers=[s0])
    ps[0].add_tools(hover0)

    if plot_both:
        ps[1] = figure(plot_width=width, plot_height=height, toolbar_location="right")

        if plot_type == "regression":
            ps[1].title.text = "Regression in Original PCA space"
            s1 = ps[1].scatter(
                "pc0_orig",
                "pc1_orig",
                source=source,
                size=20,
                fill_color={"field": "target", "transform": color_mapper},
                line_color="white",
                line_width=2.5,
            )
        else:
            ps[1].title.text = "Classification in Original PCA space"
            s1 = ps[1].scatter(
                "pc0_orig",
                "pc1_orig",
                source=source,
                size=20,
                fill_color={"field": "target", "transform": color_mapper},
                line_color="white",
                line_width=2.5,
                legend="target_string",
            )

        hover1 = HoverTool(tooltips=tooltips, callback=callback, renderers=[s1])
        ps[1].add_tools(hover1)

    for plot in ps:
        plot.xaxis.axis_label = "1st Principal Component"
        plot.yaxis.axis_label = "2nd Principal Component"

    # Original-space plot on the left, post-SGL plot on the right.
    if plot_both:
        layout = row(ps[::-1])
    else:
        layout = ps[0]

    layout.sizing_mode = sizing_mode

    if output_html is not None:
        html = file_html(layout, CDN, "my plot")
        with open(op.abspath(output_html), "w") as fp:
            fp.write(html)
    else:
        show(layout)
Exemple #10
0
                 fill_color="colors",
                 fill_alpha=0.8,
                 line_color=None)
plot.add_glyph(source, circles)

# Large caption showing the date value from zip_date_ranges[0][1].
label = Label(
    text=str(zip_date_ranges[0][1]),
    x=-100,
    y=-10,
    text_font_size='70pt',
    text_color='#FFDE8D',
)
plot.add_layout(label)

# Colour-bar bounds: min/max of the 'searches' column divided by 50000 and
# truncated toward zero by int().
color_mapper = LinearColorMapper(
    palette=['#FDE724', '#B2DD2C', '#6BCD59', '#35B778', '#1E9C89',
             '#25828E', '#30678D', '#3E4989', '#472777', '#440154'],
    low=int(min(source.data.get('searches')) / 50000),
    high=int(max(source.data.get('searches')) / 50000),
)
color_bar = ColorBar(
    color_mapper=color_mapper,
    orientation='horizontal',
    location='bottom_left',
    scale_alpha=0.7,
)
plot.add_layout(color_bar)

hover = HoverTool(tooltips=[("Market", "@dim_market"),
                            ("Country", "@dim_country_name")])
plot.add_tools(PanTool(), WheelZoomTool(), hover)
Exemple #11
0
# range bounds supplied in web mercator coordinates
p = figure(
    title="House Sale Values in King County Washington",
    x_axis_type="mercator",
    y_axis_type="mercator",
    x_range=(-13638000, -13504000),
    y_range=(5967000, 6069000),
)
p.add_tile(CARTODBPOSITRON)

# NOTE(review): coords_x is stored as 'lat' and drawn on x (coords_y -> 'lon'
# -> y). The column names look swapped relative to the axes, but the plotted
# mapping is x <- coords_x, y <- coords_y.
source = ColumnDataSource(data={
    'lat': house.coords_x.tolist(),
    'lon': house.coords_y.tolist(),
    'size': house.map_dot.tolist(),
    'color': house.price.tolist(),
    'legend': house.price.tolist(),
})

# Price colour scale with hard-coded bounds (not derived from the data).
color_mapper = LinearColorMapper(palette="Viridis256", low=78000, high=7070000)

output_notebook()

circle = Circle(
    x='lat',
    y='lon',
    size='size',
    fill_color={'field': 'color', 'transform': color_mapper},
    fill_alpha=.6,
    line_color=None,
)
p.add_glyph(source, circle)
Exemple #12
0
def nvlink(doc):
    """Attach live NVLink TX/RX throughput bar charts to a Bokeh document.

    ``_get_nvlink_throughput()`` is polled every ``update_ms`` milliseconds.
    Its ``"tx"``/``"rx"`` entries are treated as running counters: each
    refresh plots, per device index in ``range(ngpus)``, the non-negative
    difference from the previous poll scaled to a per-second rate. The
    figure titles say B/s, which presumes the counters are in bytes —
    confirm against ``_get_nvlink_throughput``. Both y-axes and the shared
    colour scale span 0 to ``_get_max_bandwidth()``.

    Parameters
    ----------
    doc : bokeh document
        Receives the layout (as a root) and the periodic callback.
    """
    # Refresh period. The per-second scale is derived from it, so changing
    # the period cannot silently skew the plotted rates.
    update_ms = 200
    per_second = 1000.0 / update_ms

    max_bw = _get_max_bandwidth()

    nvlink_state = _get_nvlink_throughput()
    # Reference snapshots; they are overwritten at the start of each refresh.
    nvlink_state["tx-ref"] = nvlink_state["tx"].copy()
    nvlink_state["rx-ref"] = nvlink_state["rx"].copy()

    # Bar i spans [i, i + 0.8] on the x-axis; rates start at zero.
    left = list(range(ngpus))
    right = [x + 0.8 for x in left]
    source = ColumnDataSource({
        "left": left,
        "right": right,
        "count-tx": [0.0] * ngpus,
        "count-rx": [0.0] * ngpus,
    })
    mapper = LinearColorMapper(palette=all_palettes["RdYlBu"][4],
                               low=0,
                               high=max_bw)

    def make_figure(title, field):
        # One byte-formatted vertical-bar figure driven by the given column.
        fig = figure(title=title,
                     sizing_mode="stretch_both",
                     y_range=[0, max_bw])
        fig.yaxis.formatter = NumeralTickFormatter(format="0.0 b")
        fig.quad(
            source=source,
            left="left",
            right="right",
            bottom=0,
            top=field,
            color={
                "field": field,
                "transform": mapper
            },
        )
        fig.toolbar_location = None
        return fig

    tx_fig = make_figure("TX NVLink [B/s]", "count-tx")
    rx_fig = make_figure("RX NVLink [B/s]", "count-rx")

    doc.title = "NVLink Utilization Counters"
    doc.add_root(column(tx_fig, rx_fig, sizing_mode="stretch_both"))

    def rates(direction):
        # Per-second rate from the latest counters vs. the previous snapshot;
        # negative deltas (e.g. counter reset) are clamped to zero.
        return [
            max(a - b, 0.0) * per_second
            for (a, b) in zip(nvlink_state[direction],
                              nvlink_state[direction + "-ref"])
        ]

    def cb():
        nvlink_state["tx-ref"] = nvlink_state["tx"].copy()
        nvlink_state["rx-ref"] = nvlink_state["rx"].copy()
        nvlink_state.update(_get_nvlink_throughput())
        source.data.update({"count-tx": rates("tx"), "count-rx": rates("rx")})

    doc.add_periodic_callback(cb, update_ms)
    def tool_handler(self, doc):
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models import widgets, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        default_palette = self.default_palette

        x_coords, y_coords = self.arr.coords[
            self.arr.dims[1]], self.arr.coords[self.arr.dims[0]]
        self.app_context.update({
            'data': self.arr,
            'cached_data': {},
            'gamma_cached_data': {},
            'plots': {},
            'data_range': self.arr.T.range(),
            'figures': {},
            'widgets': {},
            'color_maps': {}
        })

        self.app_context['color_maps']['d2'] = LinearColorMapper(
            default_palette,
            low=np.min(self.arr.values),
            high=np.max(self.arr.values),
            nan_color='black')

        self.app_context['color_maps']['curvature'] = LinearColorMapper(
            default_palette,
            low=np.min(self.arr.values),
            high=np.max(self.arr.values),
            nan_color='black')

        self.app_context['color_maps']['raw'] = LinearColorMapper(
            default_palette,
            low=np.min(self.arr.values),
            high=np.max(self.arr.values),
            nan_color='black')

        plots, figures, data_range, cached_data, gamma_cached_data = (
            self.app_context['plots'],
            self.app_context['figures'],
            self.app_context['data_range'],
            self.app_context['cached_data'],
            self.app_context['gamma_cached_data'],
        )

        cached_data['raw'] = self.arr.values
        gamma_cached_data['raw'] = self.arr.values

        figure_kwargs = {
            'tools': ['reset', 'wheel_zoom'],
            'plot_width': self.app_main_size,
            'plot_height': self.app_main_size,
            'min_border': 10,
            'toolbar_location': 'left',
            'x_range': data_range['x'],
            'y_range': data_range['y'],
            'x_axis_location': 'below',
            'y_axis_location': 'right',
        }
        figures['d2'] = figure(title='d2 Spectrum', **figure_kwargs)

        figure_kwargs.update({
            'y_range': self.app_context['figures']['d2'].y_range,
            'x_range': self.app_context['figures']['d2'].x_range,
            'toolbar_location': None,
            'y_axis_location': 'left',
        })

        figures['curvature'] = figure(title='Curvature', **figure_kwargs)
        figures['raw'] = figure(title='Raw Image', **figure_kwargs)

        figures['curvature'].yaxis.major_label_text_font_size = '0pt'

        # TODO add support for color mapper
        plots['d2'] = figures['d2'].image(
            [self.arr.values],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['d2'])
        plots['curvature'] = figures['curvature'].image(
            [self.arr.values],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['curvature'])
        plots['raw'] = figures['raw'].image(
            [self.arr.values],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['raw'])

        smoothing_sliders_by_name = {}
        smoothing_sliders = []  # need one for each axis
        axis_resolution = self.arr.T.stride(generic_dim_names=False)
        for dim in self.arr.dims:
            coords = self.arr.coords[dim]
            resolution = float(axis_resolution[dim])
            high_resolution = len(coords) / 3 * resolution
            low_resolution = resolution

            # could make this axis dependent for more reasonable defaults
            default = 15 * resolution

            if default > high_resolution:
                default = (high_resolution + low_resolution) / 2

            new_slider = widgets.Slider(title='{} Window'.format(dim),
                                        start=low_resolution,
                                        end=high_resolution,
                                        step=resolution,
                                        value=default)
            smoothing_sliders.append(new_slider)
            smoothing_sliders_by_name[dim] = new_slider

        n_smoothing_steps_slider = widgets.Slider(title="Smoothing Steps",
                                                  start=0,
                                                  end=5,
                                                  step=1,
                                                  value=2)
        beta_slider = widgets.Slider(title="β",
                                     start=-8,
                                     end=8,
                                     step=1,
                                     value=0)
        direction_select = widgets.Select(
            options=list(self.arr.dims),
            value='eV' if 'eV' in self.arr.dims else
            self.arr.dims[0],  # preference to energy,
            title='Derivative Direction')
        interleave_smoothing_toggle = widgets.Toggle(
            label='Interleave smoothing with d/dx',
            active=True,
            button_type='primary')
        clamp_spectrum_toggle = widgets.Toggle(
            label='Clamp positive values to 0',
            active=True,
            button_type='primary')
        filter_select = widgets.Select(options=['Gaussian', 'Boxcar'],
                                       value='Boxcar',
                                       title='Type of Filter')

        color_slider = widgets.RangeSlider(start=0,
                                           end=100,
                                           value=(
                                               0,
                                               100,
                                           ),
                                           title='Color Clip')
        gamma_slider = widgets.Slider(start=0.1,
                                      end=4,
                                      value=1,
                                      step=0.1,
                                      title='Gamma')

        # don't need any cacheing here for now, might if this ends up being too slow
        def smoothing_fn(n_passes):
            if n_passes == 0:
                return lambda x: x

            filter_factory = {
                'Gaussian': gaussian_filter,
                'Boxcar': boxcar_filter,
            }.get(filter_select.value, boxcar_filter)

            filter_size = {
                d: smoothing_sliders_by_name[d].value
                for d in self.arr.dims
            }
            return filter_factory(filter_size, n_passes)

        @Debounce(0.25)
        def force_update():
            n_smoothing_steps = n_smoothing_steps_slider.value
            d2_data = self.arr
            if interleave_smoothing_toggle.active:
                f = smoothing_fn(n_smoothing_steps // 2)
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
                f = smoothing_fn(n_smoothing_steps - (n_smoothing_steps // 2))
                d2_data = d1_along_axis(f(d2_data), direction_select.value)

            else:
                f = smoothing_fn(n_smoothing_steps)
                d2_data = d2_along_axis(f(self.arr), direction_select.value)

            d2_data.values[
                d2_data.values != d2_data.
                values] = 0  # remove NaN values until Bokeh fixes NaNs over the wire
            if clamp_spectrum_toggle.active:
                d2_data.values = -d2_data.values
                d2_data.values[d2_data.values < 0] = 0
            cached_data['d2'] = d2_data.values
            gamma_cached_data['d2'] = d2_data.values**gamma_slider.value
            plots['d2'].data_source.data = {'image': [gamma_cached_data['d2']]}

            curv_smoothing_fn = smoothing_fn(n_smoothing_steps)
            smoothed_curvature_data = curv_smoothing_fn(self.arr)
            curvature_data = curvature(smoothed_curvature_data,
                                       self.arr.dims,
                                       beta=beta_slider.value)
            curvature_data.values[
                curvature_data.values != curvature_data.values] = 0
            if clamp_spectrum_toggle.active:
                curvature_data.values = -curvature_data.values
                curvature_data.values[curvature_data.values < 0] = 0

            cached_data['curvature'] = curvature_data.values
            gamma_cached_data[
                'curvature'] = curvature_data.values**gamma_slider.value
            plots['curvature'].data_source.data = {
                'image': [gamma_cached_data['curvature']]
            }
            update_color_slider(color_slider.value)

        # TODO better integrate these, they can share code with the above if we are more careful.
        def take_d2(d2_data):
            n_smoothing_steps = n_smoothing_steps_slider.value
            if interleave_smoothing_toggle.active:
                f = smoothing_fn(n_smoothing_steps // 2)
                d2_data = d1_along_axis(f(d2_data), direction_select.value)
                f = smoothing_fn(n_smoothing_steps - (n_smoothing_steps // 2))
                d2_data = d1_along_axis(f(d2_data), direction_select.value)

            else:
                f = smoothing_fn(n_smoothing_steps)
                d2_data = d2_along_axis(f(self.arr), direction_select.value)

            d2_data.values[
                d2_data.values != d2_data.
                values] = 0  # remove NaN values until Bokeh fixes NaNs over the wire
            if clamp_spectrum_toggle.active:
                d2_data.values = -d2_data.values
                d2_data.values[d2_data.values < 0] = 0

            return d2_data

        def take_curvature(curvature_data, curve_dims):
            curv_smoothing_fn = smoothing_fn(n_smoothing_steps_slider.value)
            smoothed_curvature_data = curv_smoothing_fn(curvature_data)
            curvature_data = curvature(smoothed_curvature_data,
                                       curve_dims,
                                       beta=beta_slider.value)
            curvature_data.values[
                curvature_data.values != curvature_data.values] = 0
            if clamp_spectrum_toggle.active:
                curvature_data.values = -curvature_data.values
                curvature_data.values[curvature_data.values < 0] = 0

            return curvature_data

        # These functions will always be linked to the current context of the curvature tool.
        self.app_context['d2_fn'] = take_d2
        self.app_context['curvature_fn'] = take_curvature

        def force_update_change_wrapper(attr, old, new):
            if old != new:
                force_update()

        def force_update_click_wrapper(event):
            force_update()

        @Debounce(0.1)
        def update_color_slider(new):
            def update_plot(name, data):
                low, high = np.min(data), np.max(data)
                dynamic_range = high - low
                self.app_context['color_maps'][name].update(
                    low=low + new[0] / 100 * dynamic_range,
                    high=low + new[1] / 100 * dynamic_range)

            update_plot('d2', gamma_cached_data['d2'])
            update_plot('curvature', gamma_cached_data['curvature'])
            update_plot('raw', gamma_cached_data['raw'])

        @Debounce(0.1)
        def update_gamma_slider(new):
            gamma_cached_data['d2'] = cached_data['d2']**new
            gamma_cached_data['curvature'] = cached_data['curvature']**new
            gamma_cached_data['raw'] = cached_data['raw']**new
            update_color_slider(color_slider.value)

        def update_color_handler(attr, old, new):
            update_color_slider(new)

        def update_gamma_handler(attr, old, new):
            update_gamma_slider(new)

        layout = column(
            row(
                column(self.app_context['figures']['d2'],
                       interleave_smoothing_toggle, direction_select),
                column(self.app_context['figures']['curvature'], beta_slider,
                       clamp_spectrum_toggle),
                column(self.app_context['figures']['raw'], color_slider,
                       gamma_slider)),
            widgetbox(
                filter_select,
                *smoothing_sliders,
                n_smoothing_steps_slider,
            ),
            Spacer(height=100),
        )

        # Attach event handlers
        for w in (n_smoothing_steps_slider, beta_slider, direction_select,
                  *smoothing_sliders, filter_select):
            w.on_change('value', force_update_change_wrapper)

        interleave_smoothing_toggle.on_click(force_update_click_wrapper)
        clamp_spectrum_toggle.on_click(force_update_click_wrapper)

        color_slider.on_change('value', update_color_handler)
        gamma_slider.on_change('value', update_gamma_handler)

        force_update()

        doc.add_root(layout)
        doc.title = 'Curvature Tool'
Exemple #14
0
    #nx.set_node_attributes(hub_ego, attrs)
    dept_dict = {'dept': [dept_lookup[x] for x in id_dict['id']]}

    graph = from_networkx(G, nx.spring_layout, scale=2, center=(0, 0))

    TOOLS = "pan,wheel_zoom,reset"
    TOOLTIPS = [("PersonID", "@id"), ("DeptID", "@dept"),
                ("Degree", "@degree")]
    plot = figure(title="EU Email Network",
                  x_range=(-1.05, 1.05),
                  y_range=(-1.05, 1.05),
                  tools=TOOLS)
    plot.add_tools(HoverTool(tooltips=TOOLTIPS), TapTool())

    mapper = LinearColorMapper(palette=Viridis256, low=0, high=41)

    graph.selection_policy = NodesAndLinkedEdges()
    graph.inspection_policy = NodesOnly()

    graph.edge_renderer.glyph.line_alpha = 0.08
    graph.edge_renderer.glyph.line_width = 0.2
    graph.edge_renderer.glyph.line_color = 'magenta'

    graph.node_renderer.glyph.fill_color = {
        'field': 'dept',
        'transform': mapper
    }
    graph.node_renderer.glyph.line_width = 0.2
    graph.node_renderer.glyph.fill_alpha = 0.08
    graph.node_renderer.glyph.line_color = 'grey'
def create_figure(df, df_edges, df_populations, source, x_value, y_value,
                  color_value, size_value):
    """
    Create a graph of patient data (and tree structure, if available).

    The graph encodes up to four dimensions: x axis, y axis, colour and size
    of the nodes.

    :param df: dataframe for visualisation; must contain a ``populationID``
        column where -1 marks a node that belongs to no population
    :param df_edges: dataframe with edge data (edge index, from, to); the
        ``from``/``to`` columns are treated as 1-based row positions into
        ``source``
    :param df_populations: dataframe with cell populations data
        (population_name, color), indexed positionally by ``populationID``
    :param source: ColumnDataSource to be visualised; the columns
        ``pop_names``, ``lc``, ``lw`` and ``sz`` are added to it
    :param x_value: column name of the x dimension
    :param y_value: column name of the y dimension
    :param color_value: column name of the color dimension, or ``'None'``
    :param size_value: column name of the size dimension, or ``'None'``
    :return: tuple ``(p, new_columns)`` -- the figure and the list of
        TableColumns describing the plotted dimensions (empty when ``df`` is
        empty)
    """
    new_columns = []

    if df.empty:
        # Nothing to plot: hand back an empty figure with the default toolset.
        p = figure(
            plot_height=900,
            plot_width=1200,
            tools='pan, box_zoom,reset, wheel_zoom, box_select, lasso_select,tap',
            toolbar_location="above")
        return p, new_columns

    # populationID == -1 means "not assigned to any population".
    pop_names = [
        df_populations.iloc[pop_id]['population_name']
        if pop_id != -1 else '???' for pop_id in df['populationID']
    ]
    source.add(pop_names, name='pop_names')

    x_title = x_value.title()
    y_title = y_value.title()

    p = figure(
        plot_height=900,
        plot_width=1200,
        tools='pan, box_zoom,reset, wheel_zoom, box_select, tap, save',
        toolbar_location="above",
        title="%s vs %s" % (x_title, y_title))
    # Lasso only fires on mouse-up instead of on every move.
    p.add_tools(LassoSelectTool(select_every_mousemove=False))

    p.xaxis.axis_label = x_title
    p.yaxis.axis_label = y_title

    # Draw tree edges as two-point line segments between connected nodes.
    if not df_edges.empty:
        x_data = source.data[x_value]
        y_data = source.data[y_value]
        edge_xs = []
        edge_ys = []
        # Columns 1 and 2 hold the endpoints; subtract 1 because they are
        # 1-based while the source columns are 0-based.
        for start, end in zip(df_edges.iloc[:, 1], df_edges.iloc[:, 2]):
            edge_xs.append([x_data[start - 1], x_data[end - 1]])
            edge_ys.append([y_data[start - 1], y_data[end - 1]])

        p.multi_line(edge_xs, edge_ys, line_width=0.5, color='white')

    # Mark populations: outline each node with its population colour; nodes
    # without a population keep a thin white outline.
    line_color = ['white'] * len(df)
    line_width = [1] * len(df)
    if not df_populations.empty:
        line_color = [
            df_populations.iloc[pop_id]['color']
            if pop_id != -1 else 'white' for pop_id in df['populationID']
        ]
        line_width = [5 if lc != 'white' else 1 for lc in line_color]

    source.add(line_color, name='lc')
    source.add(line_width, name='lw')

    if size_value != 'None':
        # Compute the range once instead of once per row.
        size_min = df[size_value].min()
        size_max = df[size_value].max()
        # NaN and zero values fall back to a fixed size of 10.
        sizes = [
            hf.scale(value, size_min, size_max)
            if not np.isnan(value) and value != 0 else 10
            for value in df[size_value]
        ]
    else:
        sizes = [25] * len(df)
    source.add(sizes, name='sz')

    glyph_kwargs = dict(x=x_value,
                        y=y_value,
                        size='sz',
                        line_color="lc",
                        line_width="lw",
                        line_alpha=0.9,
                        alpha=0.6,
                        hover_color='white',
                        hover_alpha=0.5,
                        source=source)

    color_bar = None
    if color_value != 'None':
        mapper = LinearColorMapper(
            palette=hf.rainbow_color_map(),
            high=df[color_value].max(),
            low=df[color_value].min(),
        )
        color_bar = ColorBar(color_mapper=mapper, location=(0, 0))
        glyph_kwargs['color'] = {'field': color_value, 'transform': mapper}

    renderer = p.circle(**glyph_kwargs)
    if color_bar is not None:
        p.add_layout(color_bar, 'right')

    hover = HoverTool(tooltips=[
        ("index", "$index"),
        ("{}".format(size_value), "@{{{}}}".format(size_value)),
        ("{}".format(color_value), "@{{{}}}".format(color_value)),
        ("population", "@pop_names"),
        ("(x,y)", "($x, $y)"),
    ],
                      renderers=[renderer])

    p.add_tools(hover)
    # Nodes may be dragged but not created; make dragging the default tap action.
    draw_tool = PointDrawTool(renderers=[renderer], add=False)
    p.add_tools(draw_tool)
    p.toolbar.active_tap = draw_tool

    new_columns = [
        TableColumn(field=x_value, title=x_value),
        TableColumn(field=y_value, title=y_value),
        TableColumn(field=color_value, title=color_value),
        TableColumn(field=size_value, title=size_value),
        TableColumn(field='pop_names', title="population"),
    ]
    return p, new_columns
Exemple #16
0
def nvlink(doc):
    """Bokeh app plotting per-GPU NVLink TX and RX throughput as bar charts.

    Each GPU's utilization counters are summed over all NVLink links. Every
    ``update_interval_ms`` milliseconds the counters are re-read, and the
    byte delta since the previous sample is scaled to bytes per second.

    :param doc: Bokeh document the TX/RX figures are attached to
    """
    import subprocess as sp

    # Refresh period of the periodic callback. It is also used to turn the
    # per-sample byte delta into a per-second rate, so the two stay in sync.
    update_interval_ms = 200
    samples_per_second = 1000.0 / update_interval_ms

    # Use device-0/link-0 to get "upper bound"
    counter = 1
    nlinks = pynvml.NVML_NVLINK_MAX_LINKS
    nvlink_ver = pynvml.nvmlDeviceGetNvLinkVersion(gpu_handles[0], 0)
    nvlink_link_bw = {
        # Keys = NVLink Version, Values = Max Link BW (per direction)
        # [Note: Using specs at https://en.wikichip.org/wiki/nvidia/nvlink]
        1: 20.0 * GB,  # GB/s
        2: 25.0 * GB,  # GB/s
    }
    # Max NVLink Throughput = BW-per-link * nlinks
    max_bw = nlinks * nvlink_link_bw.get(nvlink_ver, 25.0 * GB)

    # nvmlDeviceSetNvLinkUtilizationControl seems limited, using smi:
    # NOTE(review): the exit code is ignored; if this fails the counters may
    # not be reported in bytes -- confirm whether that should be surfaced.
    sp.call([
        "nvidia-smi",
        "nvlink",
        "--setcontrol",
        str(counter) + "bz",  # Get output in bytes
    ])

    def read_counters():
        """Return per-GPU (tx, rx) lists, each summed over all links.

        Each link is queried once and both directions are taken from the
        same reading.
        """
        tx_totals = []
        rx_totals = []
        for i in range(ngpus):
            tx_sum = 0
            rx_sum = 0
            for j in range(nlinks):
                link = pynvml.nvmlDeviceGetNvLinkUtilizationCounter(
                    gpu_handles[i], j, counter)
                tx_sum += link["tx"]
                rx_sum += link["rx"]
            tx_totals.append(tx_sum)
            rx_totals.append(rx_sum)
        return tx_totals, rx_totals

    def rates(current, previous):
        """Per-second rate from two counter samples; negative deltas clamp to 0."""
        return [
            max(a - b, 0.0) * samples_per_second
            for (a, b) in zip(current, previous)
        ]

    tx_totals, rx_totals = read_counters()
    nvlink_state = {
        "tx": tx_totals,
        "rx": rx_totals,
        "tx-ref": tx_totals.copy(),
        "rx-ref": rx_totals.copy(),
    }

    left = list(range(ngpus))
    right = [x + 0.8 for x in left]
    source = ColumnDataSource({
        "left": left,
        "right": right,
        "count-tx": [0.0 for _ in range(ngpus)],
        "count-rx": [0.0 for _ in range(ngpus)],
    })
    mapper = LinearColorMapper(palette=all_palettes["RdYlBu"][4],
                               low=0,
                               high=max_bw)

    def make_figure(title, field):
        """Build one bar chart of ``field`` (bytes/s) from the shared source."""
        fig = figure(title=title,
                     sizing_mode="stretch_both",
                     y_range=[0, max_bw])
        fig.yaxis.formatter = NumeralTickFormatter(format="0.0 b")
        fig.quad(
            source=source,
            left="left",
            right="right",
            bottom=0,
            top=field,
            color={
                "field": field,
                "transform": mapper
            },
        )
        fig.toolbar_location = None
        return fig

    tx_fig = make_figure("TX NVLink [B/s]", "count-tx")
    rx_fig = make_figure("RX NVLink [B/s]", "count-rx")

    doc.title = "NVLink Utilization Counters"
    doc.add_root(column(tx_fig, rx_fig, sizing_mode="stretch_both"))

    def cb():
        nvlink_state["tx-ref"] = nvlink_state["tx"].copy()
        nvlink_state["rx-ref"] = nvlink_state["rx"].copy()
        nvlink_state["tx"], nvlink_state["rx"] = read_counters()
        source.data.update({
            "count-tx": rates(nvlink_state["tx"], nvlink_state["tx-ref"]),
            "count-rx": rates(nvlink_state["rx"], nvlink_state["rx-ref"]),
        })

    doc.add_periodic_callback(cb, update_interval_ms)
    def tool_handler(self, doc):
        from bokeh.layouts import row, column, widgetbox, Spacer
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.models.widgets.markups import Div
        from bokeh.plotting import figure

        self.arr = self.arr.copy(deep=True)

        if not isinstance(self.arr, xr.Dataset):
            self.use_dataset = False

        residual = None
        if self.use_dataset:
            raw_data = self.arr.data
            raw_data.values[np.isnan(raw_data.values)] = 0
            fit_results = self.arr.results
            residual = self.arr.residual
            residual.values[np.isnan(residual.values)] = 0
        else:
            raw_data = self.arr.attrs['original_data']
            fit_results = self.arr

        fit_direction = [d for d in raw_data.dims if d not in fit_results.dims]
        fit_direction = fit_direction[0]

        two_dimensional = False
        if len(raw_data.dims) != 2:
            two_dimensional = True
            x_coords, y_coords = fit_results.coords[
                fit_results.dims[0]], fit_results.coords[fit_results.dims[1]]
            z_coords = raw_data.coords[fit_direction]
        else:
            x_coords, y_coords = raw_data.coords[
                raw_data.dims[0]], raw_data.coords[raw_data.dims[1]]

        if two_dimensional:
            self.settings['palette'] = 'coolwarm'
        default_palette = self.default_palette

        self.app_context.update({
            'data': raw_data,
            'fits': fit_results,
            'residual': residual,
            'original': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            }
        })
        if two_dimensional:
            self.app_context['data_range']['z'] = (np.min(z_coords.values),
                                                   np.max(z_coords.values))

        figures, plots, app_widgets = self.app_context['figures'], self.app_context['plots'],\
                                      self.app_context['widgets']

        self.cursor_dims = raw_data.dims
        if two_dimensional:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y']),
                np.mean(self.data_range['z'])
            ]
        else:
            self.cursor = [
                np.mean(self.data_range['x']),
                np.mean(self.data_range['y'])
            ]

        app_widgets['fit_info_div'] = Div(text='')

        self.app_context['color_maps']['main'] = LinearColorMapper(
            default_palette,
            low=np.min(raw_data.values),
            high=np.max(raw_data.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset", "save"]
        main_title = 'Fit Inspection Tool: WARNING Unidentified'

        try:
            main_title = 'Fit Inspection Tool: {}'.format(
                raw_data.S.label[:60])
        except:
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.app_context['data_range']['y'])
        figures['main'].xaxis.axis_label = raw_data.dims[0]
        figures['main'].yaxis.axis_label = raw_data.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"

        data_for_main = raw_data
        if two_dimensional:
            data_for_main = data_for_main.sel(**dict(
                [[fit_direction, self.cursor[2]]]),
                                              method='nearest')
        plots['main'] = figures['main'].image(
            [data_for_main.values.T],
            x=self.app_context['data_range']['x'][0],
            y=self.app_context['data_range']['y'][0],
            dw=self.app_context['data_range']['x'][1] -
            self.app_context['data_range']['x'][0],
            dh=self.app_context['data_range']['y'][1] -
            self.app_context['data_range']['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        band_centers = [b.center for b in fit_results.F.bands.values()]
        bands_xs = [b.coords[b.dims[0]].values for b in band_centers]
        bands_ys = [b.values for b in band_centers]
        if fit_results.dims[0] == raw_data.dims[1]:
            bands_ys, bands_xs = bands_xs, bands_ys
        plots['band_locations'] = figures['main'].multi_line(
            xs=bands_xs,
            ys=bands_ys,
            line_color='white',
            line_width=1,
            line_dash='dashed')

        # add cursor lines
        cursor_lines = self.add_cursor_lines(figures['main'])

        # marginals
        if not two_dimensional:
            figures['bottom'] = figure(plot_width=self.app_main_size,
                                       plot_height=self.app_marginal_size,
                                       min_border=10,
                                       title=None,
                                       x_range=figures['main'].x_range,
                                       x_axis_location='above',
                                       toolbar_location=None,
                                       tools=[])
        else:
            figures['bottom'] = Spacer(width=self.app_main_size,
                                       height=self.app_marginal_size)

        right_y_range = figures['main'].y_range
        if two_dimensional:
            right_y_range = self.data_range['z']

        figures['right'] = figure(plot_width=self.app_marginal_size,
                                  plot_height=self.app_main_size,
                                  min_border=10,
                                  title=None,
                                  y_range=right_y_range,
                                  y_axis_location='left',
                                  toolbar_location=None,
                                  tools=[])

        marginal_line_width = 2
        if not two_dimensional:
            bottom_data = raw_data.sel(**dict(
                [[raw_data.dims[1], self.cursor[1]]]),
                                       method='nearest')
            right_data = raw_data.sel(**dict(
                [[raw_data.dims[0], self.cursor[0]]]),
                                      method='nearest')

            plots['bottom'] = figures['bottom'].line(
                x=bottom_data.coords[raw_data.dims[0]].values,
                y=bottom_data.values,
                line_width=marginal_line_width)
            plots['bottom_residual'] = figures['bottom'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['bottom_fit'] = figures['bottom'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['bottom_init_fit'] = figures['bottom'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')

            plots['right'] = figures['right'].line(
                y=right_data.coords[raw_data.dims[1]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            plots['right_residual'] = figures['right'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['right_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['right_init_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')
        else:
            right_data = raw_data.sel(**{
                k: v
                for k, v in self.cursor_dict.items() if k != fit_direction
            },
                                      method='nearest')
            plots['right'] = figures['right'].line(
                y=right_data.coords[right_data.dims[0]].values,
                x=right_data.values,
                line_width=marginal_line_width)
            plots['right_residual'] = figures['right'].line(
                x=[], y=[], line_color='red', line_width=marginal_line_width)
            plots['right_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='blue',
                line_width=marginal_line_width,
                line_dash='dashed')
            plots['right_init_fit'] = figures['right'].line(
                x=[],
                y=[],
                line_color='green',
                line_width=marginal_line_width,
                line_dash='dotted')

        def on_change_main_view(attr, old, data_source):
            self.selected_data = data_source
            data = None
            if data_source == 'data':
                data = raw_data.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k == fit_direction
                },
                                    method='nearest')
            elif data_source == 'residual':
                data = residual.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k == fit_direction
                },
                                    method='nearest')
            elif two_dimensional:
                data = fit_results.F.s(data_source)
                data.values[np.isnan(data.values)] = 0

            if data is not None:
                if self.remove_outliers:
                    data = data.T.clean_outliers(clip=self.outlier_clip)

                plots['main'].data_source.data = {
                    'image': [data.values.T],
                }
                update_main_colormap(None, None, main_color_range_slider.value)

        def update_fit_display():
            target = 'right'
            if fit_results.dims[0] == raw_data.dims[1]:
                target = 'bottom'

            if two_dimensional:
                target = 'right'
                current_fit = fit_results.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k != fit_direction
                },
                                              method='nearest').item()
                coord_vals = raw_data.coords[fit_direction].values
            else:
                current_fit = fit_results.sel(**dict([[
                    fit_results.dims[0],
                    self.cursor[0 if target == 'right' else 1]
                ]]),
                                              method='nearest').item()
                coord_vals = raw_data.coords[
                    raw_data.dims[0 if target == 'bottom' else 1]].values

            if current_fit is not None:
                app_widgets['fit_info_div'].text = current_fit._repr_html_(
                    short=True)  # pylint: disable=protected-access
            else:
                app_widgets['fit_info_div'].text = 'No fit here.'
                plots['{}_residual'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                plots['{}_fit'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                plots['{}_init_fit'.format(target)].data_source.data = {
                    'x': [],
                    'y': [],
                }
                return

            if target == 'bottom':
                residual_x = coord_vals
                residual_y = current_fit.residual
                init_fit_x = coord_vals
                init_fit_y = current_fit.init_fit
                fit_x = coord_vals
                fit_y = current_fit.best_fit
            else:
                residual_y = coord_vals
                residual_x = current_fit.residual
                init_fit_y = coord_vals
                init_fit_x = current_fit.init_fit
                fit_y = coord_vals
                fit_x = current_fit.best_fit

            plots['{}_residual'.format(target)].data_source.data = {
                'x': residual_x,
                'y': residual_y,
            }
            plots['{}_fit'.format(target)].data_source.data = {
                'x': fit_x,
                'y': fit_y,
            }
            plots['{}_init_fit'.format(target)].data_source.data = {
                'x': init_fit_x,
                'y': init_fit_y,
            }

        def click_right_marginal(event):
            self.cursor = [self.cursor[0], self.cursor[1], event.y]
            on_change_main_view(None, None, self.selected_data)

        def click_main_image(event):
            if two_dimensional:
                self.cursor = [event.x, event.y, self.cursor[2]]
            else:
                self.cursor = [event.x, event.y]

            if not two_dimensional:
                right_marginal_data = raw_data.sel(**dict(
                    [[raw_data.dims[0], self.cursor[0]]]),
                                                   method='nearest')
                bottom_marginal_data = raw_data.sel(**dict(
                    [[raw_data.dims[1], self.cursor[1]]]),
                                                    method='nearest')
                plots['bottom'].data_source.data = {
                    'x': bottom_marginal_data.coords[raw_data.dims[0]].values,
                    'y': bottom_marginal_data.values,
                }
            else:
                right_marginal_data = raw_data.sel(**{
                    k: v
                    for k, v in self.cursor_dict.items() if k != fit_direction
                },
                                                   method='nearest')

            plots['right'].data_source.data = {
                'y':
                right_marginal_data.coords[right_marginal_data.dims[0]].values,
                'x': right_marginal_data.values,
            }

            update_fit_display()

        def on_change_outlier_clip(attr, old, new):
            self.outlier_clip = new
            on_change_main_view(None, None, self.selected_data)

        def set_remove_outliers(should_remove_outliers):
            if self.remove_outliers != should_remove_outliers:
                self.remove_outliers = should_remove_outliers

                on_change_main_view(None, None, self.selected_data)

        update_main_colormap = self.update_colormap_for('main')
        MAIN_CONTENT_OPTIONS = [
            ('Residual', 'residual'),
            ('Data', 'data'),
        ]

        if two_dimensional:
            available_parameters = fit_results.F.parameter_names

            for param_name in available_parameters:
                MAIN_CONTENT_OPTIONS.append((
                    param_name,
                    param_name,
                ))

        remove_outliers_toggle = widgets.Toggle(label='Remove Outliers',
                                                button_type='primary',
                                                active=self.remove_outliers)
        remove_outliers_toggle.on_click(set_remove_outliers)

        outlier_clip_slider = widgets.Slider(title='Clip',
                                             start=0,
                                             end=10,
                                             value=self.outlier_clip,
                                             callback_throttle=150,
                                             step=0.2)
        outlier_clip_slider.on_change('value', on_change_outlier_clip)

        main_content_select = widgets.Dropdown(label='Main Content',
                                               button_type='primary',
                                               menu=MAIN_CONTENT_OPTIONS)
        main_content_select.on_change('value', on_change_main_view)

        # Widgety things
        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)
        figures['main'].on_event(events.Tap, click_main_image)
        if two_dimensional:
            figures['right'].on_event(events.Tap, click_right_marginal)

        layout = row(
            column(figures['main'], figures.get('bottom')),
            column(figures['right'], app_widgets['fit_info_div']),
            column(
                widgetbox(*[
                    widget for widget in [
                        self._cursor_info,
                        main_color_range_slider,
                        main_content_select,
                        remove_outliers_toggle if two_dimensional else None,
                        outlier_clip_slider if two_dimensional else None,
                    ] if widget is not None
                ]), ))

        update_fit_display()

        doc.add_root(layout)
        doc.title = 'Band Tool'
Exemple #18
0
    def tool_handler(self, doc):
        """Build the interactive analysis-function tool and add it to ``doc``.

        ``self.arr`` (which must be two dimensional) is shown as an image,
        with bottom and right marginal line cuts through the cursor; clicking
        the image moves the cursor.  One slider is created for every parameter
        of ``self.analysis_fn`` (after the first, which receives the data)
        whose entry in ``self.widget_specification`` is ``int`` or ``float``.
        Moving a slider re-runs ``self.analysis_fn(self.arr, *slider_values)``
        (debounced by 0.25 s) and redisplays the result; any exception raised
        by the analysis function is shown in an error ``Div`` instead.  The
        marginals show the processed cut and the original cut (in red).

        Args:
            doc: bokeh document the finished layout is added to.

        Raises:
            AnalysisError: if ``self.arr`` is not two dimensional.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        self.data_for_display = self.arr
        x_dim, y_dim = self.data_for_display.dims
        x_coords = self.data_for_display.coords[x_dim]
        y_coords = self.data_for_display.coords[y_dim]

        default_palette = self.default_palette

        self.app_context.update({
            'data': self.arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
        })
        data_range = self.app_context['data_range']

        figures, plots = self.app_context['figures'], self.app_context['plots']

        # Start the cursor at the centre of the data.
        # NOTE(review): self.data_range presumably mirrors the
        # app_context['data_range'] entry stored above -- confirm.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        # nan-aware limits: NaN pixels are expected (they get nan_color), so
        # one NaN must not turn the whole colour scale into NaN.
        self.color_maps['main'] = LinearColorMapper(
            default_palette,
            low=np.nanmin(self.data_for_display.values),
            high=np.nanmax(self.data_for_display.values),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = '{} Tool: WARNING Unidentified'.format(
            self.analysis_fn.__name__)

        try:
            main_title = '{} Tool: {}'.format(
                self.analysis_fn.__name__, self.data_for_display.S.label[:60])
        except Exception:
            # Labelling is best-effort: keep the fallback title.  (A bare
            # ``except:`` would also swallow KeyboardInterrupt/SystemExit.)
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = x_dim
        figures['main'].yaxis.axis_label = y_dim
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        # Transposed so that dims[0] runs along the horizontal axis.
        plots['main'] = figures['main'].image(
            [self.data_for_display.values.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        def value_span(*arrays):
            """Return ``(min, max)`` over every value of ``arrays``, ignoring NaNs."""
            values = np.concatenate([np.ravel(a) for a in arrays])
            return np.nanmin(values), np.nanmax(values)

        # Create the bottom marginal plot (cut at the cursor's y position)
        bottom_marginal = self.data_for_display.sel(
            **{y_dim: self.cursor[1]}, method='nearest')
        bottom_marginal_original = self.arr.sel(
            **{y_dim: self.cursor[1]}, method='nearest')
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=value_span(bottom_marginal.values,
                               bottom_marginal_original.values),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[x_dim].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_original'] = figures['bottom_marginal'].line(
            x=bottom_marginal_original.coords[x_dim].values,
            y=bottom_marginal_original.values,
            line_color='red')

        # Create the right marginal plot (cut at the cursor's x position)
        right_marginal = self.data_for_display.sel(
            **{x_dim: self.cursor[0]}, method='nearest')
        right_marginal_original = self.arr.sel(
            **{x_dim: self.cursor[0]}, method='nearest')
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=value_span(right_marginal.values,
                               right_marginal_original.values),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[y_dim].values,
            x=right_marginal.values)
        plots['right_marginal_original'] = figures['right_marginal'].line(
            y=right_marginal_original.coords[y_dim].values,
            x=right_marginal_original.values,
            line_color='red')

        # add lines
        self.add_cursor_lines(figures['main'])
        # Empty multi-line glyph for band lines; the handle is not used here.
        _ = figures['main'].multi_line(xs=[],
                                       ys=[],
                                       line_color='white',
                                       line_width=1)  # band lines

        # prep the widgets for the analysis function
        signature = inspect.signature(self.analysis_fn)

        # drop the first which has to be the input data, we can revisit later if this is too limiting
        parameter_names = list(signature.parameters)[1:]
        named_widgets = dict(zip(parameter_names, self.widget_specification))
        built_widgets = {}

        def update_marginals():
            """Redraw both marginals at the cursor and rescale their value axes."""
            # Read dims at call time: analysis_fn may have replaced
            # self.data_for_display since the tool was built.
            cut_x_dim, cut_y_dim = self.data_for_display.dims[:2]

            def cuts(arr):
                right = arr.sel(**{cut_x_dim: self.cursor[0]},
                                method='nearest')
                bottom = arr.sel(**{cut_y_dim: self.cursor[1]},
                                 method='nearest')
                return right, bottom

            right_processed, bottom_processed = cuts(self.data_for_display)
            right_original, bottom_original = cuts(self.arr)

            plots['bottom_marginal'].data_source.data = {
                'x': bottom_processed.coords[cut_x_dim].values,
                'y': bottom_processed.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right_processed.coords[cut_y_dim].values,
                'x': right_processed.values,
            }
            plots['bottom_marginal_original'].data_source.data = {
                'x': bottom_original.coords[cut_x_dim].values,
                'y': bottom_original.values,
            }
            plots['right_marginal_original'].data_source.data = {
                'y': right_original.coords[cut_y_dim].values,
                'x': right_original.values,
            }

            # Fit the value axes to BOTH curves; using only the original cut
            # hid the processed curve whenever analysis_fn changed its scale.
            bottom_low, bottom_high = value_span(bottom_processed.values,
                                                 bottom_original.values)
            figures['bottom_marginal'].y_range.start = bottom_low
            figures['bottom_marginal'].y_range.end = bottom_high
            right_low, right_high = value_span(right_processed.values,
                                               right_original.values)
            figures['right_marginal'].x_range.start = right_low
            figures['right_marginal'].x_range.end = right_high

        def click_main_image(event):
            self.cursor = [event.x, event.y]
            update_marginals()

        error_msg = widgets.Div(text='')

        @Debounce(0.25)
        def update_data_for_display():
            # Re-run the analysis with the current slider values; on failure
            # keep the previous result and surface the error message.
            try:
                self.data_for_display = self.analysis_fn(
                    self.arr, *[
                        built_widgets[p].value for p in parameter_names
                        if p in built_widgets
                    ])
                error_msg.text = ''
            except Exception as e:
                error_msg.text = '{}'.format(e)

            # flush + update
            update_marginals()
            plots['main'].data_source.data = {
                'image': [self.data_for_display.values.T]
            }

        def update_data_change_wrapper(attr, old, new):
            if old != new:
                update_data_for_display()

        # Only int and float specifications get a widget; other parameters
        # are left out of the analysis_fn call.
        for parameter_name, specification in named_widgets.items():
            widget = None
            if specification == int:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0,
                                        title=parameter_name)
            elif specification == float:
                widget = widgets.Slider(start=-20,
                                        end=20,
                                        value=0.,
                                        step=0.1,
                                        title=parameter_name)

            if widget is not None:
                built_widgets[parameter_name] = widget
                widget.on_change('value', update_data_change_wrapper)

        update_main_colormap = self.update_colormap_for('main')

        # NOTE(review): identity placeholder for app_context['run']; its
        # consumer is outside this method -- confirm intent.
        self.app_context['run'] = lambda x: x

        main_color_range_slider = widgets.RangeSlider(start=0,
                                                      end=100,
                                                      value=(
                                                          0,
                                                          100,
                                                      ),
                                                      title='Color Range')

        # Attach callbacks
        main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal']),
            column(
                widgetbox(*[
                    built_widgets[p] for p in parameter_names
                    if p in built_widgets
                ]),
                widgetbox(
                    self._cursor_info,
                    main_color_range_slider,
                    error_msg,
                )))

        doc.add_root(layout)
        # NOTE(review): this title (and the AnalysisError message above)
        # look copied from the band tool -- confirm they are intended here.
        doc.title = 'Band Tool'
Exemple #19
0
    def __init__(
        self,
        ydeg,
        npix,
        npts,
        nmaps,
        throttle_time,
        nosmooth,
        gp,
        sample_function,
    ):
        """Build the sample-map images, light-curve plots, widgets and JS callback.

        For each of ``nmaps`` samples this creates an image of the map and a
        panel of six light curves.  Images and curves are recomputed in the
        browser by a ``CustomJS`` callback from the design matrices and the
        ylm coefficients.  The final layout is stored in ``self.layout``.

        Args:
            ydeg: presumably the spherical-harmonic degree; the JS callback
                uses ``(ydeg + 1)**2`` coefficients per map.
            npix: image height in pixels; each map image is reshaped to
                ``(npix, 2 * npix)``.
            npts: number of points per light curve (``np.linspace(0, 2, npts)``).
            nmaps: number of sample maps / light-curve columns to draw.
            throttle_time: stored on ``self``; not used in this method.
            nosmooth: if truthy, the "smooth" toggle is created disabled.
            gp: stored on ``self``; not used in this method.
            sample_function: called as ``sample_function(r, a, b, c, n)``,
                with ``a, b = gauss2beta(mu_l, sigma_l)``.  Element ``[0]`` of
                its result is used as the ylm coefficients, indexed per map as
                ``self.ylm[i]``.
        """
        # Settings
        self.ydeg = ydeg
        self.npix = npix
        self.npts = npts
        self.throttle_time = throttle_time
        self.nosmooth = nosmooth
        self.nmaps = nmaps
        self.gp = gp

        # Design matrices
        # A_I maps ylm -> image pixels and A_F maps ylm -> light-curve points
        # (see the `@` products below).
        self.A_I = get_intensity_design_matrix(ydeg, npix)
        self.A_F = get_flux_design_matrix(ydeg, npts)

        def sample_ylm(r, mu_l, sigma_l, c, n):
            # Avoid issues at the boundaries
            # (nudge mu_l off exactly 0 and 90 before calling gauss2beta)
            if mu_l == 0:
                mu_l = 1e-2
            elif mu_l == 90:
                mu_l = 90 - 1e-2
            a, b = gauss2beta(mu_l, sigma_l)
            return sample_function(r, a, b, c, n)

        self.sample_ylm = sample_ylm

        # Draw initial samples using the default values in the module-level
        # `params` dict.  `[0]` keeps the first element of the sampler's
        # result; it is indexed per map as self.ylm[i] below.
        # TODO confirm what the other returned elements are.
        self.ylm = self.sample_ylm(
            params["size"]["r"]["value"],
            params["latitude"]["mu"]["value"],
            params["latitude"]["sigma"]["value"],
            params["contrast"]["c"]["value"],
            params["contrast"]["n"]["value"],
        )[0]

        # Plot the GP ylm samples
        self.color_mapper = LinearColorMapper(palette="Plasma256",
                                              nan_color="white",
                                              low=0.5,
                                              high=1.2)
        self.moll_plot = [None for i in range(self.nmaps)]
        # One image per map: 1 + A_I @ ylm, reshaped to (npix, 2 * npix).
        self.moll_source = [
            ColumnDataSource(data=dict(image=[
                1.0 +
                (self.A_I @ self.ylm[i]).reshape(self.npix, 2 * self.npix)
            ])) for i in range(self.nmaps)
        ]
        # eps pads the axis ranges; epsp slightly enlarges the image extent
        # (presumably to avoid an edge gap -- confirm).
        eps = 0.1
        epsp = 0.02
        # Upper/lower halves of the 2:1 ellipse outline drawn around each map
        # (presumably a Mollweide projection, per the `moll_` names).
        xe = np.linspace(-2, 2, 300)
        ye = 0.5 * np.sqrt(4 - xe**2)
        for i in range(self.nmaps):
            self.moll_plot[i] = figure(
                plot_width=280,
                plot_height=130,
                toolbar_location=None,
                x_range=(-2 - eps, 2 + eps),
                y_range=(-1 - eps / 2, 1 + eps / 2),
            )
            self.moll_plot[i].axis.visible = False
            self.moll_plot[i].grid.visible = False
            self.moll_plot[i].outline_line_color = None
            self.moll_plot[i].image(
                image="image",
                x=-2,
                y=-1,
                dw=4 + epsp,
                dh=2 + epsp / 2,
                color_mapper=self.color_mapper,
                source=self.moll_source[i],
            )
            self.moll_plot[i].toolbar.active_drag = None
            self.moll_plot[i].toolbar.active_scroll = None
            self.moll_plot[i].toolbar.active_tap = None

        # Plot lat/lon grid
        lat_lines = get_latitude_lines()
        lon_lines = get_longitude_lines()
        for i in range(self.nmaps):
            for x, y in lat_lines:
                self.moll_plot[i].line(x,
                                       y,
                                       line_width=1,
                                       color="black",
                                       alpha=0.25)
            for x, y in lon_lines:
                self.moll_plot[i].line(x,
                                       y,
                                       line_width=1,
                                       color="black",
                                       alpha=0.25)
            self.moll_plot[i].line(xe,
                                   ye,
                                   line_width=3,
                                   color="black",
                                   alpha=1)
            self.moll_plot[i].line(xe,
                                   -ye,
                                   line_width=3,
                                   color="black",
                                   alpha=1)

        # Colorbar slider
        self.slider = RangeSlider(
            start=0,
            end=1.5,
            step=0.01,
            value=(0.5, 1.2),
            orientation="horizontal",
            show_value=False,
            css_classes=["colorbar-slider"],
            direction="ltr",
            title="cmap",
        )
        self.slider.on_change("value", self.slider_callback)

        # Buttons
        self.seed_button = Button(
            label="re-seed",
            button_type="default",
            css_classes=["seed-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
        )
        self.seed_button.on_click(self.seed_callback)

        self.smooth_button = Toggle(
            label="smooth",
            button_type="default",
            css_classes=["smooth-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
            active=False,
        )
        self.smooth_button.disabled = bool(self.nosmooth)
        self.smooth_button.on_click(self.smooth_callback)

        # No callback is attached to the "auto" toggle in this method.
        self.auto_button = Toggle(
            label="auto",
            button_type="default",
            css_classes=["auto-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
            active=True,
        )

        self.reset_button = Button(
            label="reset",
            button_type="default",
            css_classes=["reset-button"],
            sizing_mode="fixed",
            height=30,
            width=75,
        )
        self.reset_button.on_click(self.reset_callback)

        # Light curve samples
        # Six curves per map, one per design-matrix row A_F[j], coloured with
        # Plasma6 and tagged with the `inc` values 15..90.
        self.flux_plot = [None for i in range(self.nmaps)]
        self.flux_source = [
            ColumnDataSource(data=dict(
                xs=[np.linspace(0, 2, npts) for j in range(6)],
                ys=[fluxnorm(self.A_F[j] @ self.ylm[i]) for j in range(6)],
                color=[Plasma6[5 - j] for j in range(6)],
                inc=[15, 30, 45, 60, 75, 90],
            )) for i in range(self.nmaps)
        ]
        for i in range(self.nmaps):
            self.flux_plot[i] = figure(
                toolbar_location=None,
                x_range=(0, 2),
                y_range=None,
                min_border_left=50,
                plot_height=400,
            )
            if i == 0:
                self.flux_plot[i].yaxis.axis_label = "flux [ppt]"
                self.flux_plot[i].yaxis.axis_label_text_font_style = "normal"
            self.flux_plot[i].xaxis.axis_label = "rotational phase"
            self.flux_plot[i].xaxis.axis_label_text_font_style = "normal"
            self.flux_plot[i].outline_line_color = None
            self.flux_plot[i].multi_line(
                xs="xs",
                ys="ys",
                line_color="color",
                source=self.flux_source[i],
            )
            self.flux_plot[i].toolbar.active_drag = None
            self.flux_plot[i].toolbar.active_scroll = None
            self.flux_plot[i].toolbar.active_tap = None
            self.flux_plot[i].yaxis.major_label_orientation = np.pi / 4
            self.flux_plot[i].xaxis.axis_label_text_font_size = "8pt"
            self.flux_plot[i].xaxis.major_label_text_font_size = "8pt"
            self.flux_plot[i].yaxis.axis_label_text_font_size = "8pt"
            self.flux_plot[i].yaxis.major_label_text_font_size = "8pt"

        # Javascript callback to update light curves & images
        # The JS indexes A_F, A_I and ylm as flat arrays:
        #   A_F in (A_F.shape[0], npts, kmax) order,
        #   A_I in (npix, 2 * npix, kmax) order,
        #   ylm in (nmaps, kmax) order, where kmax = (ydeg + 1)**2.
        # It rewrites each flux curve as 1e3 * ((1 + f) / (1 + mean) - 1)
        # (ppt, matching the axis label) and each image as 1 + A_I @ ylm.
        # Doubled braces escape literal JS braces for str.format.
        self.A_F_source = ColumnDataSource(data=dict(A_F=self.A_F))
        self.A_I_source = ColumnDataSource(data=dict(A_I=self.A_I))
        self.ylm_source = ColumnDataSource(data=dict(ylm=self.ylm))
        callback = CustomJS(
            args=dict(
                A_F_source=self.A_F_source,
                A_I_source=self.A_I_source,
                ylm_source=self.ylm_source,
                flux_source=self.flux_source,
                moll_source=self.moll_source,
            ),
            code="""
            var A_F = A_F_source.data['A_F'];
            var A_I = A_I_source.data['A_I'];
            var ylm = ylm_source.data['ylm'];
            var i, j, k, l, m, n;
            for (n = 0; n < {nmax}; n++) {{

                // Update the light curves
                var flux = flux_source[n].data['ys'];
                for (l = 0; l < {lmax}; l++) {{
                    for (m = 0; m < {mmax}; m++) {{
                        flux[l][m] = 0.0;
                        for (k = 0; k < {kmax}; k++) {{
                            flux[l][m] += A_F[{kmax} * ({mmax} * l + m) + k] * ylm[{kmax} * n + k];
                        }}
                    }}
                    // Normalize
                    var mean = flux[l].reduce((previous, current) => current += previous) / {mmax};
                    for (m = 0; m < {mmax}; m++) {{
                        flux[l][m] = 1e3 * ((1 + flux[l][m]) / (1 + mean) - 1)
                    }}
                }}
                flux_source[n].change.emit();

                // Update the images
                var image = moll_source[n].data['image'][0];
                for (i = 0; i < {imax}; i++) {{
                    for (j = 0; j < {jmax}; j++) {{
                        image[{jmax} * i + j] = 1.0;
                        for (k = 0; k < {kmax}; k++) {{
                            image[{jmax} * i + j] += A_I[{kmax} * ({jmax} * i + j) + k] * ylm[{kmax} * n + k];
                        }}
                    }}
                }}
                moll_source[n].change.emit();

            }}
            """.format(
                imax=self.npix,
                jmax=2 * self.npix,
                kmax=(self.ydeg + 1)**2,
                nmax=self.nmaps,
                lmax=self.A_F.shape[0],
                mmax=self.npts,
            ),
        )
        # Invisible (alpha=0) dummy glyph: the JS callback fires whenever its
        # `size` property changes.  Presumably other code changes `size` to
        # trigger a browser-side redraw -- confirm.
        self.js_dummy = self.flux_plot[0].circle(x=0, y=0, size=1, alpha=0)
        self.js_dummy.glyph.js_on_change("size", callback)

        # Full layout
        self.plots = row(
            *[
                column(m, f, sizing_mode="scale_both")
                for m, f in zip(self.moll_plot, self.flux_plot)
            ],
            margin=(10, 30, 10, 30),
            sizing_mode="scale_both",
            css_classes=["samples"],
        )
        self.layout = grid([[self.plots]])
Exemple #20
0
def main():
    """Parse command-line options, read an MP3 file and plot it with bokeh.

    With ``-t none`` each channel's amplitude is plotted against time.  With
    ``-t stft`` (the default) each channel's STFT output is shown as an
    image.  Every per-channel figure shares the first figure's axis ranges.

    Returns:
        0 on success; 1 if no input was given; 2 if the input path is not an
        existing file; 3 if reading the MP3 failed; 4 if the transform name
        is not recognized.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Read an mp3 file and plot out its pitch')
    parser.add_argument('-i',
                        dest='input',
                        type=str,
                        help='Input file path',
                        default='')
    # store_true flags must default to a bool; the old '' default was
    # forwarded to MP3.read as a string.
    parser.add_argument('-n',
                        dest='normalize',
                        action='store_true',
                        help='Normalize input values',
                        default=False)
    parser.add_argument(
        '-t',
        dest='transform',
        type=str,
        help='Use different transforms on the input audio signal',
        default='stft')
    parser.add_argument(
        '-s',
        dest='sample',
        type=int,
        help=
        'Sampling rate in Hz to use on the input audio signal while transforming',
        default=100)
    parser.add_argument(
        '-w',
        dest='window',
        type=float,
        help='Sampling window in s to use on the input audio signal for stft',
        default=0.05)
    options = parser.parse_args()

    # Error check
    if options.input == '':
        print("No input given. BYE!\n")
        return 1
    if not os.path.isfile(options.input):
        print(f"Given input path {options.input} does not exist!")
        return 2

    # Read input file into frame rate and data
    try:
        inSignal = MP3.read(options.input, options.normalize)
    except Exception as err:
        # Report the cause instead of silently discarding it.
        print(f"Reading MP3 failed: {err}")
        return 3

    figures = []

    def new_figure(y_axis_label, plot_height):
        """Create and register a time-axis figure; later ones share ranges."""
        linked_ranges = {}
        if figures:
            linked_ranges = dict(x_range=figures[0].x_range,
                                 y_range=figures[0].y_range)
        fig = bkfigure(plot_width=1200,
                       plot_height=plot_height,
                       x_axis_label='Time',
                       y_axis_label=y_axis_label,
                       **linked_ranges)
        figures.append(fig)
        return fig

    # Plot the data for quick visualization
    if options.transform == 'none':
        for channel in range(inSignal.channels):
            fig = new_figure('Amp', 600)
            fig.line(inSignal.time, inSignal.audioData[:, channel])
    elif options.transform == 'stft':
        # STFT over the signal
        fSignal = Transforms.STFT(inSignal,
                                  windowEvery=1 / options.sample,
                                  windowLength=options.window)
        for channel in range(inSignal.channels):
            fig = new_figure('Frequency', 400)
            channelData = fSignal.audioData[:, :, channel]
            channelAmp = np.max(channelData)
            fig.image(image=[channelData],
                      x=0,
                      y=0,
                      dw=fSignal.time[-1],
                      dh=fSignal.dimensionAxes[0][-1],
                      color_mapper=LinearColorMapper(high=channelAmp,
                                                     low=0,
                                                     palette=Inferno11))
    else:
        print("Unrecognized transform given!")
        return 4

    bkshow(bkcolumn(*figures))
    return 0
Exemple #21
0
def main():
    """Render commute-time maps to ``output_file`` and open them with ``bk.show``.

    The pickle dump is a dict holding ``'lat'`` / ``'long'`` square centres plus
    one sequence of commute times per ``<name>_towork`` / ``<name>_tohome`` key.
    One Google-Maps patch plot is drawn per key, followed by an overlay plot:

    * red squares: two or more commutes exceed ``--max_happy_commute``;
    * orange squares: exactly one commute exceeds it;
    * uncovered squares: every commute is within the threshold.
    """
    import argparse
    from load_config import load_config

    parser = argparse.ArgumentParser()
    parser.add_argument('pickle_dump')
    parser.add_argument('output_file')
    parser.add_argument('-c',
                        '--config_filename',
                        dest='config_filename',
                        help="Config file with private info",
                        default=None)
    parser.add_argument(
        '--max_happy_commute',
        default=45,
        type=float,
        help=
        "For plot that overlays all the commutes, what's the longest not colored red?"
    )
    # No ``type=float`` here so that a non-numeric value can trigger the
    # "centre on the plotted squares" fallback further down.
    parser.add_argument('--center_lat',
                        default=34.053695,
                        help="latitude to center on")
    parser.add_argument('--center_lng',
                        default=-118.430208,
                        help="longitude to center on")
    parser.add_argument(
        '--zoom',
        default=11,
        type=int,
        help="initial zoom of maps.  goes 1 (least zoomed) to 20 (most zoomed)"
    )
    parser.add_argument(
        '--map_type',
        default='roadmap',
        help="Google Maps base layer: roadmap, satellite, hybrid or terrain"
    )
    parser.add_argument(
        '--palette',
        default='Viridis',
        help="Palette to use.  Must be in bokeh.palettes.all_palettes")
    parser.add_argument(
        '--ncolors',
        type=int,
        default=256,
        help=
        "Number of colors to use.  Must be able to access bokeh.palettes.all_palettes[<palette>][<ncolors>]"
    )
    parser.add_argument('--cbar_min', default=15, type=float)
    parser.add_argument('--cbar_max', default=75, type=float)

    args = parser.parse_args()

    config, _timezone = load_config(args.config_filename)  # timezone unused here
    api_key = config['api_key']

    # NOTE: pickle.load can execute arbitrary code -- only load dumps you
    # produced yourself.
    with open(args.pickle_dump, 'rb') as f:
        print(f"Loading from {args.pickle_dump}")
        data = pickle.load(f)

    center_lats = np.array(data.pop('lat'))
    center_longs = np.array(data.pop('long'))
    # Remaining keys are '<name>_<destination>'.
    names = set(k.split('_')[0] for k in data)

    # Half of the largest spacing between consecutive centres is used as the
    # half-width / half-height of every square.
    dx = np.max(center_longs[1:] - center_longs[:-1]) / 2
    dy = np.max(center_lats[1:] - center_lats[:-1]) / 2

    # Corner order: x = [left, left, right, right], y = [bottom, top, top, bottom].
    xcoords = [[xc - dx, xc - dx, xc + dx, xc + dx] for xc in center_longs]
    ycoords = [[yc - dy, yc + dy, yc + dy, yc - dy] for yc in center_lats]
    print(f"Plotting {len(xcoords)} squares")

    try:
        args.center_lat = float(args.center_lat)
    except ValueError:
        # Midpoint between the lowest bottom edge and the highest top edge.
        args.center_lat = (min(yc[0] for yc in ycoords) +
                           max(yc[1] for yc in ycoords)) / 2

    try:
        args.center_lng = float(args.center_lng)
    except ValueError:
        # Midpoint between the leftmost and rightmost edges; these come from
        # the longitude (x) corners, not the latitude ones.
        args.center_lng = (min(xc[0] for xc in xcoords) +
                           max(xc[2] for xc in xcoords)) / 2

    plots = []
    bk.output_file(args.output_file, title="Commute times")
    moptions = GMapOptions(lat=args.center_lat,
                           lng=args.center_lng,
                           zoom=args.zoom,
                           map_type=args.map_type)

    destinations = ['towork', 'tohome']
    allkeys = [
        f'{name}_{destkey}'
        for name, destkey in product(names, destinations)
    ]

    color_mapper = LinearColorMapper(
        palette=all_palettes[args.palette][args.ncolors],
        low=args.cbar_min,
        high=args.cbar_max)

    # nhappy[j] counts how many commutes for square j are <= max_happy_commute.
    nhappy = np.zeros(len(xcoords), dtype=int)
    for name in names:
        for destkey in destinations:
            key = f'{name}_{destkey}'
            commute_times = data[key]
            plots.append(
                plot_patches_on_gmap(xcoords,
                                     ycoords,
                                     api_key,
                                     values=commute_times,
                                     map_options=moptions,
                                     title=restructure_key(key),
                                     color_mapper=color_mapper))

            nhappy[np.array(commute_times) <= args.max_happy_commute] += 1

    ## now overlap all the commutes:
    n_commutes = len(allkeys)
    x_arr = np.array(xcoords)
    y_arr = np.array(ycoords)
    two_or_more_unhappy = nhappy < n_commutes - 1
    exactly_one_unhappy = nhappy == n_commutes - 1

    title = f'Areas where all commutes are < {args.max_happy_commute} minutes'
    plot = plot_patches_on_gmap(
        list(x_arr[two_or_more_unhappy]),
        list(y_arr[two_or_more_unhappy]),
        api_key,
        map_options=moptions,
        title=title,
        solid_fill='red')

    # Translucent orange overlay for squares with a single over-threshold commute.
    overlay_source = bk.ColumnDataSource(
        data=dict(xs=list(x_arr[exactly_one_unhappy]),
                  ys=list(y_arr[exactly_one_unhappy])))
    plot.patches('xs',
                 'ys',
                 fill_alpha=0.25,
                 fill_color='orange',
                 source=overlay_source,
                 line_width=0)

    plots.append(plot)

    ## now show
    grid = gridplot(plots, ncols=2)
    bk.show(grid)

# Saving the locations list to be passed to the dropdown menu
location_list = list(loc_group.groups.keys())

# Build the person x person co-location table and label both axes with people.
# NOTE(review): presumably ``heatmap`` returns a square matrix whose rows and
# columns are both ordered like ``person_list`` -- confirm against its definition.
cooccur_mat, person_list = heatmap(cc_data)
heat_df = pd.DataFrame(cooccur_mat)
heat_df.columns=person_list
heat_df.index=person_list
# Axis names become the column names produced by stack()/reset_index() below.
heat_df.columns.name='colnames'
heat_df.index.name='rownames'
xLabels = list(heat_df.index)
yLabels = list(heat_df.columns)
# Long format: one row per (rownames, colnames) pair with its 'colocation' value;
# dropna=False keeps NaN cells so every pair gets a row.
# NOTE(review): newer pandas (2.1+) reworked DataFrame.stack (``future_stack``)
# and the ``dropna`` argument -- check against the pinned pandas version.
heat_plot_df = pd.DataFrame(heat_df.stack(dropna=False),columns=["colocation"]).reset_index()
# 9-color palette linearly stretched over the observed colocation range.
colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]
mapper = LinearColorMapper(palette=colors, low=heat_plot_df.colocation.min(), high=heat_plot_df.colocation.max())


heatmap_source = ColumnDataSource(heat_plot_df)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"

# Categorical axes: x uses the row labels, y the column labels.
# No glyph is added in this snippet; presumably the cells are drawn from
# ``heatmap_source`` with ``mapper`` further on -- TODO confirm.
hp = figure(title="Colocation heatmap",
           x_range=xLabels, y_range=yLabels,
           x_axis_location="above",y_axis_location="right", plot_width=800, plot_height=600,
           tools=TOOLS, toolbar_location='below')

# Hide grid, axis lines and ticks; keep only small, tight tick labels.
hp.grid.grid_line_color = None
hp.axis.axis_line_color = None
hp.axis.major_tick_line_color = None
hp.axis.major_label_text_font_size = "5pt"
hp.axis.major_label_standoff = 0
Exemple #23
0
    def tool_handler_2d(self, doc):
        """Build the interactive 2D Bokeh viewer for ``self.arr`` and add it to ``doc``.

        The layout is:

        * a main image of the array;
        * a bottom marginal, which fixes ``arr.dims[1]`` at ``cursor[1]`` and
          therefore runs along ``arr.dims[0]``;
        * a right marginal, which fixes ``arr.dims[0]`` at ``cursor[0]``;
        * tabs for colour settings, scan info and symmetry-point placement.

        Tapping the main image moves the cursor, refreshes both marginals and
        calls ``self.save_app()``.

        Args:
            doc: Bokeh document that receives the layout.
        """
        from bokeh import events
        from bokeh.layouts import row, column, widgetbox, Spacer
        from bokeh.models import ColumnDataSource, widgets
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models.widgets.markups import Div
        from bokeh.plotting import figure

        arr = self.arr
        x_dim, y_dim = arr.dims[0], arr.dims[1]
        x_coords, y_coords = arr.coords[x_dim], arr.coords[y_dim]

        # Styling: subtracted data presumably takes both signs, hence the
        # diverging palette.
        default_palette = self.default_palette
        if arr.S.is_subtracted:
            default_palette = cc.coolwarm

        error_alpha = 0.3
        error_fill = '#3288bd'

        # Application Organization
        self.app_context.update({
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
            'show_stat_variation': False,
            'color_mode': 'linear',
        })
        data_range = self.app_context['data_range']

        def select_marginals():
            """Return the (bottom, right) 1D cuts of ``arr`` nearest to ``self.cursor``."""
            bottom = arr.sel(**{y_dim: self.cursor[1]}, method='nearest')
            right = arr.sel(**{x_dim: self.cursor[0]}, method='nearest')
            return bottom, right

        def stats_patch_from_data(data, subsampling_rate=None):
            """Return a closed ``values +/- sqrt(values)`` outline for a 1D cut.

            NOTE(review): the sqrt error assumes non-negative, count-like data;
            negative values yield NaN.
            """
            if subsampling_rate is None:
                # Stride of len/50 (about 50 points), capped at 5 and at least 1.
                subsampling_rate = int(min(data.values.shape[0] / 50, 5))
                if subsampling_rate == 0:
                    subsampling_rate = 1

            x_values = data.coords[data.dims[0]].values[::subsampling_rate]
            values = data.values[::subsampling_rate]
            sq = np.sqrt(values)
            lower, upper = values - sq, values + sq

            # Walk forward along the lower edge, then back along the upper edge.
            return {
                'x': np.append(x_values, x_values[::-1]),
                'y': np.append(lower, upper[::-1]),
            }

        def update_stat_variation(plot_name, data):
            """Refresh the error patch of the 'bottom' or 'right' marginal."""
            patch_data = stats_patch_from_data(data)
            if plot_name != 'right':
                plots[plot_name +
                      '_marginal_err'].data_source.data = patch_data
            else:
                # The right plot is on transposed axes.
                plots[plot_name + '_marginal_err'].data_source.data = {
                    'x': patch_data['y'],
                    'y': patch_data['x'],
                }

        figures, plots, app_widgets = self.app_context[
            'figures'], self.app_context['plots'], self.app_context['widgets']

        if self.cursor_default is not None and len(self.cursor_default) == 2:
            self.cursor = self.cursor_default
        else:
            # Try a sensible default: the centre of the data range.
            self.cursor = [
                np.mean(data_range['x']),
                np.mean(data_range['y'])
            ]

        # create the main inset plot
        prepped_main_image = self.prep_image(arr)
        self.app_context['color_maps']['main'] = LinearColorMapper(
            default_palette,
            low=np.min(prepped_main_image),
            high=np.max(prepped_main_image),
            nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset", "save"]
        main_title = 'Bokeh Tool: WARNING Unidentified'
        try:
            main_title = "Bokeh Tool: %s" % arr.S.label[:60]
        except Exception:
            # Best effort: keep the generic title if there is no usable label.
            pass
        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=data_range['x'],
                                 y_range=data_range['y'])
        figures['main'].xaxis.axis_label = x_dim
        figures['main'].yaxis.axis_label = y_dim
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"
        # Transposed so that x_dim runs horizontally in the image glyph.
        plots['main'] = figures['main'].image(
            [prepped_main_image.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        app_widgets['info_div'] = Div(text='',
                                      width=self.app_marginal_size,
                                      height=100)

        bottom_marginal, right_marginal = select_marginals()

        # Create the bottom marginal plot
        figures['bottom_marginal'] = figure(
            plot_width=self.app_main_size,
            plot_height=200,
            title=None,
            x_range=figures['main'].x_range,
            y_range=(np.min(bottom_marginal.values),
                     np.max(bottom_marginal.values)),
            x_axis_location='above',
            toolbar_location=None,
            tools=[])
        plots['bottom_marginal'] = figures['bottom_marginal'].line(
            x=bottom_marginal.coords[x_dim].values,
            y=bottom_marginal.values)
        plots['bottom_marginal_err'] = figures['bottom_marginal'].patch(
            x=[],
            y=[],
            color=error_fill,
            fill_alpha=error_alpha,
            line_color=None)

        # Create the right marginal plot
        figures['right_marginal'] = figure(
            plot_width=200,
            plot_height=self.app_main_size,
            title=None,
            y_range=figures['main'].y_range,
            x_range=(np.min(right_marginal.values),
                     np.max(right_marginal.values)),
            y_axis_location='left',
            toolbar_location=None,
            tools=[])
        plots['right_marginal'] = figures['right_marginal'].line(
            y=right_marginal.coords[y_dim].values,
            x=right_marginal.values)
        plots['right_marginal_err'] = figures['right_marginal'].patch(
            x=[],
            y=[],
            color=error_fill,
            fill_alpha=error_alpha,
            line_color=None)

        self.add_cursor_lines(figures['main'])

        # Attach tools and callbacks
        toggle = widgets.Toggle(label="Show Stat. Variation",
                                button_type="success",
                                active=False)

        def set_show_stat_variation(should_show):
            self.app_context['show_stat_variation'] = should_show
            if should_show:
                bottom_data, right_data = select_marginals()
                update_stat_variation('bottom', bottom_data)
                update_stat_variation('right', right_data)
            plots['bottom_marginal_err'].visible = bool(should_show)
            plots['right_marginal_err'].visible = bool(should_show)

        toggle.on_click(set_show_stat_variation)

        scan_keys = [
            'x', 'y', 'z', 'pass_energy', 'hv', 'location', 'id', 'probe_pol',
            'pump_pol'
        ]
        present_keys = [k for k in scan_keys if k in arr.attrs]
        scan_info_source = ColumnDataSource({
            'keys': present_keys,
            # NaN floats are stringified so the table shows them as text.
            'values': [
                str(v) if isinstance(v, float) and np.isnan(v) else v
                for v in (arr.attrs[k] for k in present_keys)
            ],
        })
        scan_info_columns = [
            widgets.TableColumn(field='keys', title='Attr.'),
            widgets.TableColumn(field='values', title='Value'),
        ]

        POINTER_MODES = [
            ('Cursor', 'cursor'),
            ('Path', 'path'),
        ]

        COLOR_MODES = [
            ('Adaptive Hist. Eq. (Slow)', 'adaptive_equalization'),
            # ('Histogram Eq.', 'equalization',), # not implemented
            ('Linear', 'linear'),
            # ('Log', 'log',), # not implemented
        ]

        def on_change_color_mode(attr, old, new_color_mode):
            self.app_context['color_mode'] = new_color_mode
            if old is None or old != new_color_mode:
                # Only the main panel is an image in this tool; the marginals
                # are line glyphs, so there is nothing else to re-prepare.
                plots['main'].data_source.data = {
                    'image': [self.prep_image(arr).T]
                }
                # update_main_colormap is bound further down, before any
                # callback can fire.
                update_main_colormap(None, None, main_color_range_slider.value)

        color_mode_dropdown = widgets.Dropdown(label='Color Mode',
                                               button_type='primary',
                                               menu=COLOR_MODES)
        color_mode_dropdown.on_change('value', on_change_color_mode)

        symmetry_point_name_input = widgets.TextInput(
            title='Symmetry Point Name', value="G")
        snap_checkbox = widgets.CheckboxButtonGroup(labels=['Snap Axes'],
                                                    active=[])
        place_symmetry_point_at_cursor_button = widgets.Button(
            label="Place Point", button_type="primary")

        def place_symmetry_point():
            """Store the cursor as a named symmetry point in ``arr.attrs``.

            When 'Snap Axes' is active, each coordinate is snapped to the
            nearest existing symmetry point on that axis if it lies within
            that axis's tolerance. Axes without a tolerance are never snapped.
            """
            cursor_dict = dict(zip(arr.dims, self.cursor))
            skip_dimensions = {'eV', 'delay', 'cycle'}
            if 'symmetry_points' not in arr.attrs:
                arr.attrs['symmetry_points'] = {}

            snap_distance = {
                'phi': 2,
                'beta': 2,
                'kx': 0.01,
                'ky': 0.01,
                'kz': 0.01,
                'kp': 0.01,
                'hv': 4,
            }

            cursor_dict = {
                k: v
                for k, v in cursor_dict.items() if k not in skip_dimensions
            }
            snapped = dict(cursor_dict)

            if 'Snap Axes' in [
                    snap_checkbox.labels[i] for i in snap_checkbox.active
            ]:
                for axis, value in cursor_dict.items():
                    if axis not in snap_distance:
                        # No tolerance defined: keep the raw cursor value.
                        continue
                    options = [
                        point[axis]
                        for point in arr.attrs['symmetry_points'].values()
                        if axis in point
                    ]
                    options = sorted(options, key=lambda x: np.abs(x - value))
                    if options and np.abs(options[0] -
                                          value) < snap_distance[axis]:
                        snapped[axis] = options[0]

            arr.attrs['symmetry_points'][
                symmetry_point_name_input.value] = snapped

        place_symmetry_point_at_cursor_button.on_click(place_symmetry_point)

        main_color_range_slider = widgets.RangeSlider(
            start=0, end=100, value=(
                0,
                100,
            ), title='Color Range (Main)')

        pointer_mode_dropdown = widgets.Dropdown(label='Pointer Mode',
                                                 button_type='primary',
                                                 menu=POINTER_MODES)
        settings_panel = widgets.Panel(
            child=widgetbox(
                Div(text='<h2>Colorscale:</h2>'),
                color_mode_dropdown,
                main_color_range_slider,
                Div(text='<h2 style="padding-top: 30px;">General Settings:</h2>'),
                toggle,
                self._cursor_info,
                sizing_mode='scale_width'),
            title='Settings')
        info_panel = widgets.Panel(
            child=widgetbox(
                app_widgets['info_div'],
                Div(text='<h2 style="padding-top: 30px; padding-bottom: 10px;">Scan Info</h2>'),
                widgets.DataTable(source=scan_info_source,
                                  columns=scan_info_columns,
                                  width=400,
                                  height=400),
                sizing_mode='scale_width',
                width=400),
            title='Info')
        preparation_panel = widgets.Panel(
            child=widgetbox(
                Div(text='<h2>Preparation</h2>'),
                symmetry_point_name_input,
                snap_checkbox,
                place_symmetry_point_at_cursor_button,
                sizing_mode='scale_width'),
            title='Preparation')

        layout = row(
            column(figures['main'], figures['bottom_marginal']),
            column(figures['right_marginal'], Spacer(width=200, height=200)),
            column(
                widgetbox(pointer_mode_dropdown),
                widgets.Tabs(tabs=[settings_panel, info_panel, preparation_panel],
                             width=400)))

        update_main_colormap = self.update_colormap_for('main')

        def click_main_image(event):
            """Move the cursor to the tapped point and refresh both marginals."""
            self.cursor = [event.x, event.y]

            bottom_marginal_data, right_marginal_data = select_marginals()
            plots['bottom_marginal'].data_source.data = {
                'x': bottom_marginal_data.coords[x_dim].values,
                'y': bottom_marginal_data.values,
            }
            plots['right_marginal'].data_source.data = {
                'y': right_marginal_data.coords[y_dim].values,
                'x': right_marginal_data.values,
            }
            if self.app_context['show_stat_variation']:
                update_stat_variation('right', right_marginal_data)
                update_stat_variation('bottom', bottom_marginal_data)
            # Rescale the value axes of the marginals to the new cuts.
            figures['bottom_marginal'].y_range.start = np.min(
                bottom_marginal_data.values)
            figures['bottom_marginal'].y_range.end = np.max(
                bottom_marginal_data.values)
            figures['right_marginal'].x_range.start = np.min(
                right_marginal_data.values)
            figures['right_marginal'].x_range.end = np.max(
                right_marginal_data.values)

            self.save_app()

        figures['main'].on_event(events.Tap, click_main_image)
        main_color_range_slider.on_change('value', update_main_colormap)

        doc.add_root(layout)
        doc.title = "Bokeh Tool"
        self.load_app()
        self.save_app()
p = figure(plot_width=800,
           plot_height=800,
           match_aspect=True,
           tools=['pan', 'box_zoom', 'reset'],
           title='',
           sizing_mode='scale_height',
           output_backend="webgl")
# t-SNE scatter coloured by ``cls_color_mapper``.
# NOTE(review): assumes cls_color_mapper is a Bokeh field spec dict with a
# 'transform' key (as built by factor_cmap / linear_cmap) -- confirm.
cr = p.circle(x='tsne_x', y='tsne_y', color=cls_color_mapper, source=source)
# Selected points keep their colours; unselected points fade to 5% alpha.
cr.selection_glyph = Circle(fill_color=cls_color_mapper,
                            line_color=cls_color_mapper)
cr.nonselection_glyph = Circle(fill_color=cls_color_mapper,
                               line_color=cls_color_mapper,
                               fill_alpha=0.05,
                               line_alpha=0.05)
# The colour bar is always attached but hidden; it is only shown (and
# re-pointed at the real mapper) when the class colouring is continuous.
color_bar = ColorBar(color_mapper=LinearColorMapper(palette="Viridis256",
                                                    low=1,
                                                    high=10),
                     label_standoff=12,
                     border_line_color=None,
                     location=(0, 0))
p.add_layout(color_bar, 'right')
color_bar.visible = False
if isinstance(cls_color_mapper['transform'], LinearColorMapper):
    color_bar.color_mapper = cls_color_mapper['transform']
    color_bar.visible = True
# Define widgets
hover_tip_tool = HoverTool(tooltips=generate_tooltip_html(),
                           show_arrow=False,
                           renderers=[cr])
wheel_zoom_tool = WheelZoomTool()
Exemple #25
0
    def plot_ortho(self):
        """Create the orthographic-projection figure of the map.

        Draws the ``"ortho"`` image column of ``self.source_ortho`` with a
        Plasma256 colormap spanning ``[self.vmin_o, self.vmax_o]``. It overlays
        faint latitude and longitude lines and installs JS callbacks:

        * a mouse-wheel callback that advances the rotation index and swaps in
          the matching map, longitude lines and flux;
        * enter/leave callbacks that toggle a global ``DISABLE_WHEEL`` JS flag.

        Returns:
            The configured Bokeh figure.
        """
        fig = figure(
            aspect_ratio=2,
            toolbar_location=None,
            x_range=(-2, 2),
            y_range=(-1, 1),
            id="plot_ortho",
            name="plot_ortho",
            min_border_left=0,
            min_border_right=0,
            css_classes=["plot_ortho_{:d}".format(self.counter)],
        )
        # Bare canvas: no axes, grid or outline around the projected disk.
        fig.axis.visible = False
        fig.grid.visible = False
        fig.outline_line_color = None

        fig.image(
            image="ortho",
            x=-1,
            y=-1,
            dw=2,
            dh=2,
            color_mapper=LinearColorMapper(
                palette="Plasma256",
                nan_color="white",
                low=self.vmin_o,
                high=self.vmax_o,
            ),
            source=self.source_ortho,
        )
        for slot in ("active_drag", "active_scroll", "active_tap"):
            setattr(fig.toolbar, slot, None)

        # Faint lat/lon graticule.
        graticule_style = dict(line_width=1, color="black", alpha=0.25)
        for lat_x, lat_y in get_ortho_latitude_lines(inc=self.inc):
            fig.line(lat_x, lat_y, **graticule_style)
        n_lon = len(self.lon_x[0])
        for k in range(n_lon):
            fig.line("x%d" % k, "y%d" % k,
                     source=self.source_ortho_lon, **graticule_style)
        self.add_border(fig, "ortho")

        # Interaction: rotate the star as the mouse wheel moves.
        wheel_js = """
                // Update the current theta index
                var delta = cb_obj["delta"];
                var t = source_index.data["t"][0];
                t += delta * speed;
                while (t < 0) t += nt;
                while (t > nt - 1) t -= nt;
                source_index.data["t"][0] = t;
                source_index.change.emit();
                var tidx = Math.floor(t);
                while (tidx < 0) tidx += nt;
                while (tidx > nt - 1) tidx -= nt;

                // Update the map
                source_ortho.data["ortho"][0] = ortho[tidx];
                source_ortho.change.emit();

                // Update the longitude lines
                var k;
                for (var k = 0; k < nlon; k++) {
                    source_ortho_lon.data["x" + k] = lon_x[tidx][k];
                    source_ortho_lon.data["y" + k] = lon_y[tidx][k];
                }
                source_ortho_lon.change.emit();

                // Update the flux
                source_flux.data["flux"] = flux[tidx];
                source_flux.data["flux0"] = flux0[tidx];
                source_flux.change.emit();
                """
        wheel_args = {
            "source_ortho": self.source_ortho,
            "source_index": self.source_index,
            "source_flux": self.source_flux,
            "source_ortho_lon": self.source_ortho_lon,
            "lon_x": self.lon_x,
            "lon_y": self.lon_y,
            "nlon": n_lon,
            "ortho": self.ortho,
            "flux0": self.flux0,
            "flux": self.flux,
            "npix_o": self.npix_o,
            "nt": self.nt,
            "speed": self.nt / 200,
        }
        fig.js_on_event(MouseWheel, CustomJS(args=wheel_args, code=wheel_js))

        # Disable page scrolling while the pointer is over the plot.
        for event_cls, flag in ((MouseEnter, "true"), (MouseLeave, "false")):
            fig.js_on_event(
                event_cls,
                CustomJS(code="\n            DISABLE_WHEEL = " + flag + ";\n            "),
            )

        return fig
    def tool_handler(self, doc):
        """Build the A/B comparison tool and add it to ``doc``.

        Shows ``self.arr`` (A), ``self.other`` (B) and their difference
        ``A - intensity * shift(B)``. The shift along both axes and the
        intensity factor are controlled by sliders, and the difference image
        is recomputed on change, debounced by 0.5 s.

        Args:
            doc: Bokeh document that receives the layout.
        """
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models import widgets
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.plotting import figure

        default_palette = self.default_palette
        difference_palette = cc.coolwarm

        intensity_slider = widgets.Slider(
            title='Relative Intensity Scaling', start=0.5, end=1.5,
            step=0.005, value=1)

        # Compute the unshifted difference before it is recorded in app_context,
        # so the context never holds a stale or missing 'compared' entry.
        self.compared = self.arr - self.other

        self.app_context.update({
            'A': self.arr,
            'B': self.other,
            'compared': self.compared,
            'plots': {},
            'figures': {},
            'widgets': {},
            'data_range': self.arr.T.range(),
            'color_maps': {},
        })

        self.color_maps['main'] = LinearColorMapper(
            default_palette, low=np.min(self.arr.values), high=np.max(self.arr.values), nan_color='black')

        figure_kwargs = {
            'tools': ['reset', 'wheel_zoom'],
            'plot_width': self.app_main_size,
            'plot_height': self.app_main_size,
            'min_border': 10,
            'toolbar_location': 'left',
            'x_range': self.data_range['x'],
            'y_range': self.data_range['y'],
            'x_axis_location': 'below',
            'y_axis_location': 'right',
        }

        self.figures['A'] = figure(title='Spectrum A', **figure_kwargs)
        self.figures['B'] = figure(title='Spectrum B', **figure_kwargs)
        self.figures['compared'] = figure(title='Comparison', **figure_kwargs)

        # Symmetric limits for the diverging difference colormap.
        # NOTE(review): the limits come from A's range (not the difference)
        # and grow like sqrt of A's magnitude -- confirm this heuristic is intended.
        diff_low, diff_high = np.min(self.arr.values), np.max(self.arr.values)
        diff_range = np.sqrt((abs(diff_low) + 1) * (abs(diff_high) + 1)) * 1.5
        self.color_maps['difference'] = LinearColorMapper(
            difference_palette, low=-diff_range, high=diff_range, nan_color='white')

        # All three images share the same placement in data coordinates.
        image_extent = dict(
            x=self.data_range['x'][0],
            y=self.data_range['y'][0],
            dw=self.data_range['x'][1] - self.data_range['x'][0],
            dh=self.data_range['y'][1] - self.data_range['y'][0],
        )
        self.plots['A'] = self.figures['A'].image(
            [self.arr.values], color_mapper=self.color_maps['main'], **image_extent)
        self.plots['B'] = self.figures['B'].image(
            [self.other.values], color_mapper=self.color_maps['main'], **image_extent)
        self.plots['compared'] = self.figures['compared'].image(
            [self.compared.values], color_mapper=self.color_maps['difference'],
            **image_extent)

        x_axis_name = self.arr.dims[0]
        y_axis_name = self.arr.dims[1]

        stride = self.arr.T.stride()
        delta_x_axis = stride['x']
        delta_y_axis = stride['y']

        # Shifts range over +/-20 grid steps in half-step increments.
        delta_x_slider = widgets.Slider(
            title='{} Shift'.format(x_axis_name), start=-20 * delta_x_axis,
            step=delta_x_axis / 2, end=20 * delta_x_axis, value=0)

        delta_y_slider = widgets.Slider(
            title='{} Shift'.format(y_axis_name), start=-20 * delta_y_axis,
            step=delta_y_axis / 2, end=20 * delta_y_axis, value=0)

        @Debounce(0.5)
        def update_summed_figure(attr, old, new):
            # Bokeh on_change signature; the arguments are ignored because the
            # result depends on all three sliders at once.
            # Slider values are in coordinate units; ndimage.shift wants pixels.
            pixel_shift = [
                delta_x_slider.value / delta_x_axis,
                delta_y_slider.value / delta_y_axis,
            ]
            # Linear interpolation; samples shifted in from outside become NaN.
            shifted = intensity_slider.value * scipy.ndimage.shift(
                self.other.values, pixel_shift,
                order=1, prefilter=False, cval=np.nan)
            self.compared = self.arr - xr.DataArray(
                shifted,
                coords=self.arr.coords,
                dims=self.arr.dims)

            self.compared.attrs.update(**self.arr.attrs)
            # The difference is a derived dataset and must not inherit A's id.
            self.compared.attrs.pop('id', None)

            self.app_context['compared'] = self.compared
            self.plots['compared'].data_source.data = {
                'image': [self.compared.values]
            }

        layout = column(
            row(
                column(self.app_context['figures']['A']),
                column(self.app_context['figures']['B']),
            ),
            row(
                column(self.app_context['figures']['compared']),
                widgetbox(
                    intensity_slider,
                    delta_x_slider,
                    delta_y_slider,
                ),
            )
        )

        update_summed_figure(None, None, None)

        for slider in (delta_x_slider, delta_y_slider, intensity_slider):
            slider.on_change('value', update_summed_figure)

        doc.add_root(layout)
        doc.title = 'Comparison Tool'
Exemple #27
0
colors = YlOrRd8[0:5][::-1]
sources = {}

# One scatter layer (and ColumnDataSource) per FF_PCT bin, each with its own
# colour. Bins are (bins[i], bins[i + 1]], and the first bin is also closed on
# the left, like pd.cut(..., include_lowest=True). Values lying exactly on a
# bin edge are therefore plotted rather than silently dropped.
for i in range(len(colors)):
    lower_ok = (df['FF_PCT'] >= bins[i]) if i == 0 else (df['FF_PCT'] > bins[i])
    mask = df[lower_ok & (df['FF_PCT'] <= bins[i + 1])]
    cds = ColumnDataSource(mask)
    sources[i] = cds
    fig.circle('x',
               'y',
               line_color=None,
               fill_color=colors[i],
               size=5,
               source=sources[i])

# ColorBar Legend: the same palette stretched over 0-100 percentile.
color_mapper = LinearColorMapper(palette=colors, low=0, high=100)
color_bar = ColorBar(color_mapper=color_mapper,
                     ticker=FixedTicker(ticks=[0, 20, 40, 60, 80, 100]),
                     label_standoff=12,
                     border_line_color=None,
                     location=(0, 0),
                     title='Percentile')
fig.add_layout(color_bar, 'right')


## Callback function for widgets
# Change data source (i.e. which DataFrame to use)
def callback_change_dataframe(new):
    # New datasources
    global df_dict, card_type_dict
    card_type = card_type_dict[metrocard_type_buttons.active]  # Get card type
                      hover_fill_color = color, hover_alpha=.5, hover_line_color='white')
w = Whisker(source=source_df_20_fok_nov,
            base="fok_20", upper="upper", lower="lower", level="overlay", line_color=color)
w.upper_head.line_color = color
w.lower_head.line_color = color
fig_2_a.add_layout(w)
fig_2_a.add_tools(HoverTool(renderers=[plot], tooltips=[('Fok quintile','@fok_20'),('Avg.','@log_novelty')], mode='vline'))

# Figure 2_b: one scatter panel per outcome variable
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
var = ['log_cit_10','bt','fail']
varnames = ['Fwd cites 10y (log)','Breakthrough rate','Failure rate']
colors = Viridis256
# One colour mapper per variable, spanning that variable's observed range in df_20.
mappers = [
    LinearColorMapper(palette=colors, low=df_20[var_name].min(), high=df_20[var_name].max())
    for var_name in var
]

source_df_20 = ColumnDataSource(df_20)
scatters = []
# The panels share a ~1000px row, so each square panel gets an equal slice.
plot_size = round(1000 / len(var))
for i, (var_name, var_label) in enumerate(zip(var, varnames)):
    scatters.append(figure(title=var_label, x_axis_label='Quintiles of fok',
                           plot_width=plot_size, plot_height=plot_size, tools=TOOLS, toolbar_location='below',
                           tooltips=[('fok:', '@fok_20'), ('novelty:', '@novelty_20'),
                                     ('Avg. {}'.format(var_label), '@{}'.format(var_name))]))
    # Minimal axis styling: no grid, no axis/tick lines, small slanted labels.
    scatters[i].grid.grid_line_color = None
    scatters[i].axis.axis_line_color = None
    scatters[i].axis.major_tick_line_color = None
    scatters[i].axis.major_label_text_font_size = "7px"
    scatters[i].axis.major_label_standoff = 0
    scatters[i].xaxis.major_label_orientation = math.pi / 3
    
Exemple #29
0
        def update_plots(new):

            print("Starting update")

            nonlocal Estimators

            if not isinstance(Estimators, (type(np.array), list)):
                Estimators = np.array(Estimators)

            estimator_names = np.array(list(estimator_select.value))
            ix = np.isin(Estimator_Names, estimator_names)
            estimator_indices = [int(i) for i in np.where(ix)[0].flatten()]

            estimators = np.array(Estimators)[estimator_indices]

            variable1 = drop1.value
            variable2 = drop2.value
            y = drop3.value

            #Things to update:
            # image background i.e. image source √
            # observation source √
            #Color mapper values√
            #hover tool values √
            #Figure ranges √
            #Model score text things √

            #Lets calculate all the image and observation data first

            plots = [None for i in range(len(estimators))]
            image_sources = [None for i in range(len(estimators))]
            observation_sources = [None for i in range(len(estimators))]
            hover_tools = [None for i in range(len(estimators))]
            model_score_sources = [None for i in range(len(estimators))]
            glyphs0 = [None for i in range(len(estimators))]
            color_bars = [None for i in range(len(estimators))]
            p_circles = [None for i in range(len(estimators))]
            p_images = [None for i in range(len(estimators))]

            #Iterate over the estimators
            for idx, estimator in enumerate(estimators):
                #Find the title for each plot
                estimator_name = str(estimator()).split('(')[0]

                #Extract the needed data
                full_mat = X[[variable1, variable2, y]].dropna(how="any",
                                                               axis=0)

                #Define a class bijection for class colour mapping
                unique_classes, y_bijection = np.unique(full_mat[y],
                                                        return_inverse=True)
                full_mat['y_bijection'] = y_bijection

                #Rescale the X Data so that the data fits nicely on the axis/predictions are reliable
                full_mat[variable1 + "_s"] = StandardScaler().fit_transform(
                    full_mat[variable1].values.reshape((-1, 1)))
                full_mat[variable2 + "_s"] = StandardScaler().fit_transform(
                    full_mat[variable2].values.reshape((-1, 1)))

                #Define the Step size in the mesh
                delta = Delta

                #Separate the data into arrays so it is easy to work with
                X1 = full_mat[variable1 + "_s"].values
                X2 = full_mat[variable2 + "_s"].values
                Y = full_mat["y_bijection"].values

                #Define the mesh-grid co-ordiantes over which to colour in
                x1_min, x1_max = X1.min() - 0.5, X1.max() + 0.5
                x2_min, x2_max = X2.min() - 0.5, X2.max() + 0.5

                #Create the meshgrid itself
                x1, x2 = np.arange(x1_min, x1_max,
                                   delta), np.arange(x2_min, x2_max, delta)
                x1x1, x2x2 = np.meshgrid(x1, x2)

                #Create the train test split
                X_train, X_test, y_train, y_test = train_test_split(
                    full_mat[[variable1 + "_s", variable2 + "_s"]],
                    Y,
                    test_size=Test_Size,
                    random_state=Random_State)
                #Fit and predict/score the model
                model = estimator().fit(X=X_train, y=y_train)
                # train_preds = model.predict(X_train)
                # test_preds = model.predict(X_test)
                model_score = model.score(X_test, y_test)
                model_score_text = "Model score: %.2f" % model_score

                if hasattr(model, "decision_function"):
                    Z = model.decision_function(np.c_[x1x1.ravel(),
                                                      x2x2.ravel()])

                elif hasattr(model, "predict_proba"):
                    Z = model.predict_proba(np.c_[x1x1.ravel(), x2x2.ravel()])

                else:
                    print(
                        "This Estimator doesn't have a decision_function attribute and can't predict probabilities"
                    )

                Z = np.argmax(Z, axis=1)
                Z_uniques = np.unique(Z)

                unique_predictions = unique_classes[Z_uniques]

                Z = Z.reshape(x1x1.shape)

                #Add in the probabilities and predicitions for the tooltips
                full_mat["probability"] = np.amax(model.predict_proba(
                    full_mat[[variable1 + "_s", variable2 + "_s"]]),
                                                  axis=1)

                bijected_predictions = model.predict(
                    full_mat[[variable1 + "_s", variable2 + "_s"]])
                full_mat["prediction"] = unique_classes[bijected_predictions]

                #Add an associated color to the predictions
                number_of_colors = len(np.unique(y_bijection))

                #Create the hover tool to be updated
                hover = HoverTool(tooltips=[(
                    variable1, "@" +
                    variable1), (variable2, "@" +
                                 variable2), ("Probability", "@probability"),
                                            ("Prediction",
                                             "@prediction"), ("Actual",
                                                              "@" + y)])

                #Create the axes for all the plots
                plots[idx] = figure(x_axis_label=variable1,
                                    y_axis_label=variable2,
                                    title=estimator_name,
                                    x_range=(x1x1.min(), x1x1.max()),
                                    y_range=(x2x2.min(), x2x2.max()),
                                    plot_height=600,
                                    plot_width=600)

                #Create all the image sources
                image_data = dict()
                image_data['x'] = np.array([x1x1.min()])
                image_data["y"] = np.array([x2x2.min()])
                image_data['dw'] = np.array([x1x1.max() - x1x1.min()])
                image_data['dh'] = np.array([x2x2.max() - x2x2.min()])
                image_data['boundaries'] = [Z]

                image_sources[idx] = ColumnDataSource(image_data)

                #Create all the updatable images (boundaries)
                p_images[idx] = plots[idx].image(image='boundaries',
                                                 x='x',
                                                 y='y',
                                                 dw='dw',
                                                 dh='dh',
                                                 palette="RdBu11",
                                                 source=image_sources[idx])

                #Create the sources to update the observation points
                observation_sources[idx] = ColumnDataSource(data=full_mat)

                #Create all the updatable points
                low = full_mat["y_bijection"].min()
                high = full_mat["y_bijection"].max()
                cbar_mapper = LinearColorMapper(palette=RdBu[number_of_colors],
                                                high=high,
                                                low=low)

                p_circles[idx] = plots[idx].circle(
                    x=variable1 + "_s",
                    y=variable2 + "_s",
                    color=dict(field='y_bijection', transform=cbar_mapper),
                    source=observation_sources[idx],
                    line_color="black")

                #Create the hovertool for each plot
                hover_tools[idx] = hover

                #Add the hover tools to each plot
                plots[idx].add_tools(hover_tools[idx])

                #Create all the text sources (model scores) for the plots
                model_score_sources[idx] = ColumnDataSource(
                    data=dict(x=[x1x1.min() + 0.3],
                              y=[x2x2.min() + 0.3],
                              text=[model_score_text]))

                #Add the model scores to all the plots
                score_as_text = Text(x="x", y="y", text="text")
                glyphs0[idx] = plots[idx].add_glyph(model_score_sources[idx],
                                                    score_as_text)

                #Add a colorbar
                color_bars[idx] = ColorBar(
                    color_mapper=cbar_mapper,
                    ticker=BasicTicker(desired_num_ticks=number_of_colors),
                    label_standoff=12,
                    location=(0, 0),
                    bar_line_color="black")

                plots[idx].add_layout(color_bars[idx], "right")
                plots[idx].add_tools(LassoSelectTool(), WheelZoomTool())

                # configure so that no drag tools are active
                plots[idx].toolbar.tools = plots[idx].toolbar.tools[1:]
                plots[idx].toolbar.tools[0], plots[idx].toolbar.tools[
                    -2] = plots[idx].toolbar.tools[-2], plots[
                        idx].toolbar.tools[0]

            layout = gridplot([
                widgetbox(drop1, drop2, drop3, estimator_select, update_drop)
            ], [row(plot) for plot in plots])
            return layout

            #Finished the callback
            print("Ending Update")
            push_notebook(handle=handle0)
Exemple #30
0
    def tool_handler(self, doc):
        """Build the interactive Band Tool application on ``doc``.

        The tool shows ``self.arr`` as an image and lets the user create named
        bands and click points onto the active band. It also registers
        ``pack_bands`` and ``fit`` in ``self.app_context``, so the collected
        band descriptions can be passed to ``fit_patterned_bands``.

        Args:
            doc: Bokeh document that receives the layout; its title is set to
                ``'Band Tool'``.

        Raises:
            AnalysisError: If ``self.arr`` is not two-dimensional.
        """
        from bokeh.layouts import row, column, widgetbox
        from bokeh.models.mappers import LinearColorMapper
        from bokeh.models import widgets
        from bokeh.plotting import figure

        if len(self.arr.shape) != 2:
            raise AnalysisError(
                'Cannot use the band tool on non image-like spectra')

        arr = self.arr
        # dims[0] runs along the x axis, dims[1] along the y axis.
        x_coords, y_coords = arr.coords[arr.dims[0]], arr.coords[arr.dims[1]]

        default_palette = self.default_palette

        self.app_context.update({
            'bands': {},
            'center_float': None,
            'data': arr,
            'data_range': {
                'x': (np.min(x_coords.values), np.max(x_coords.values)),
                'y': (np.min(y_coords.values), np.max(y_coords.values)),
            },
            'direction_normal': True,
            'fit_mode': 'mdc',
        })

        figures = self.app_context['figures']
        plots = self.app_context['plots']
        app_widgets = self.app_context['widgets']

        # Start the cursor in the middle of the data range.
        self.cursor = [
            np.mean(self.data_range['x']),
            np.mean(self.data_range['y'])
        ]

        self.color_maps['main'] = LinearColorMapper(default_palette,
                                                    low=np.min(arr.values),
                                                    high=np.max(arr.values),
                                                    nan_color='black')

        main_tools = ["wheel_zoom", "tap", "reset"]
        main_title = 'Band Tool: WARNING Unidentified'

        try:
            main_title = 'Band Tool: {}'.format(arr.S.label[:60])
        except Exception:
            # Best effort: keep the fallback title if no label is available.
            pass

        figures['main'] = figure(tools=main_tools,
                                 plot_width=self.app_main_size,
                                 plot_height=self.app_main_size,
                                 min_border=10,
                                 min_border_left=50,
                                 toolbar_location='left',
                                 x_axis_location='below',
                                 y_axis_location='right',
                                 title=main_title,
                                 x_range=self.data_range['x'],
                                 y_range=self.data_range['y'])
        figures['main'].xaxis.axis_label = arr.dims[0]
        figures['main'].yaxis.axis_label = arr.dims[1]
        figures['main'].toolbar.logo = None
        figures['main'].background_fill_color = "#fafafa"

        # Transpose so dims[0] is horizontal; stretch over the data range.
        data_range = self.app_context['data_range']
        plots['main'] = figures['main'].image(
            [arr.values.T],
            x=data_range['x'][0],
            y=data_range['y'][0],
            dw=data_range['x'][1] - data_range['x'][0],
            dh=data_range['y'][1] - data_range['y'][0],
            color_mapper=self.app_context['color_maps']['main'])

        # Cursor lines plus one white polyline per band (filled in by
        # update_band_display).
        self.add_cursor_lines(figures['main'])
        band_lines = figures['main'].multi_line(xs=[],
                                                ys=[],
                                                line_color='white',
                                                line_width=1)

        def append_point_to_band():
            """Append the current cursor position to the active band, if any."""
            cursor = self.cursor
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'].append(
                    list(cursor))
                update_band_display()

        def click_main_image(event):
            """Move the cursor to the tap; in 'band' mode also record a point."""
            self.cursor = [event.x, event.y]
            if self.pointer_mode == 'band':
                append_point_to_band()

        update_main_colormap = self.update_colormap_for('main')

        # Dropdown menus as (label, value) pairs.
        POINTER_MODES = [('Cursor', 'cursor'), ('Band', 'band')]
        FIT_MODES = [('EDC', 'edc'), ('MDC', 'mdc')]
        DIRECTIONS = [
            ('From Bottom/Left', 'forward'),
            ('From Top/Right', 'reverse'),
        ]
        BAND_TYPES = [
            ('Lorentzian', 'Lorentzian'),
            ('Voigt', 'Voigt'),
            ('Gaussian', 'Gaussian'),
        ]

        band_classes = {
            'Lorentzian': band.Band,
            'Gaussian': band.BackgroundBand,
            'Voigt': band.VoigtBand,
        }

        self.app_context['band_options'] = []

        def pack_bands():
            """Convert the UI band descriptions into fit specifications.

            Raises:
                AnalysisError: If any band has no points.
            """
            packed_bands = {}
            for band_name, band_description in self.app_context['bands'].items():
                if not band_description['points']:
                    raise AnalysisError('Band {} is empty.'.format(band_name))

                # Use the band's own center constraint, else the global one;
                # leave it unset when neither parses as a float.
                stray = None
                try:
                    stray = float(band_description['center_float'])
                except (KeyError, ValueError, TypeError):
                    try:
                        stray = float(self.app_context['center_float'])
                    except Exception:
                        pass

                packed_bands[band_name] = {
                    'name': band_name,
                    'band': band_classes.get(band_description['type'],
                                             band.Band),
                    'dims': self.arr.dims,
                    'params': {
                        'amplitude': {
                            'min': 0
                        },
                    },
                    'points': band_description['points'],
                }

                if stray is not None:
                    packed_bands[band_name]['params']['stray'] = stray

            return packed_bands

        def fit(override_data=None):
            """Fit the packed bands to ``override_data`` or to ``self.arr``.

            EDC mode fits along 'eV'; otherwise the first non-'eV' dimension
            is used.
            """
            packed_bands = pack_bands()
            dims = list(self.arr.dims)
            if 'eV' in dims:
                dims.remove('eV')
            angular_direction = dims[0]
            if isinstance(override_data, xr.Dataset):
                override_data = normalize_to_spectrum(override_data)
            return fit_patterned_bands(
                override_data if override_data is not None else self.arr,
                packed_bands,
                fit_direction='eV' if self.app_context['fit_mode'] == 'edc'
                else angular_direction,
                direction_normal=self.app_context['direction_normal'])

        self.app_context['pack_bands'] = pack_bands
        self.app_context['fit'] = fit

        self.pointer_dropdown = widgets.Dropdown(label='Pointer Mode',
                                                 button_type='primary',
                                                 menu=POINTER_MODES)
        self.direction_dropdown = widgets.Dropdown(label='Fit Direction',
                                                   button_type='primary',
                                                   menu=DIRECTIONS)
        self.band_dropdown = widgets.Dropdown(
            label='Active Band',
            button_type='primary',
            menu=self.app_context['band_options'])
        self.fit_mode_dropdown = widgets.Dropdown(label='Mode',
                                                  button_type='primary',
                                                  menu=FIT_MODES)
        self.band_type_dropdown = widgets.Dropdown(label='Band Type',
                                                   button_type='primary',
                                                   menu=BAND_TYPES)

        self.band_name_input = widgets.TextInput(placeholder='Band name...')
        self.center_float_widget = widgets.TextInput(
            placeholder='Center Constraint')
        self.center_float_copy = widgets.Button(label='Copy to all...')
        self.add_band_button = widgets.Button(label='Add Band')

        self.clear_band_button = widgets.Button(label='Clear Band')
        self.remove_band_button = widgets.Button(label='Remove Band')

        self.main_color_range_slider = widgets.RangeSlider(start=0,
                                                           end=100,
                                                           value=(0, 100),
                                                           title='Color Range')

        def add_band(band_name):
            """Register a new empty Lorentzian band called ``band_name``.

            Empty names and names that already exist are ignored. The first
            band added becomes the active band.
            """
            if not band_name or band_name in self.app_context['bands']:
                return

            self.app_context['band_options'].append((band_name, band_name))
            self.band_dropdown.menu = self.app_context['band_options']
            self.app_context['bands'][band_name] = {
                'type': 'Lorentzian',
                'points': [],
                'name': band_name,
                'center_float': None,
            }

            if self.active_band is None:
                self.active_band = band_name

            self.save_app()

        def on_copy_center_float():
            """Apply the global center constraint to every band, then save once."""
            center_float = self.app_context['center_float']
            for band_description in self.app_context['bands'].values():
                band_description['center_float'] = center_float
            self.save_app()

        def on_change_active_band(attr, old_band_id, band_id):
            self.app_context['active_band'] = band_id
            self.active_band = band_id

        def on_change_pointer_mode(attr, old_pointer_mode, pointer_mode):
            self.app_context['pointer_mode'] = pointer_mode
            self.pointer_mode = pointer_mode

        def set_center_float_value(attr, old_value, new_value):
            """Store the typed center constraint globally and on the active band."""
            self.app_context['center_float'] = new_value
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][
                    self.active_band]['center_float'] = new_value

            self.save_app()

        def set_fit_direction(attr, old_direction, new_direction):
            self.app_context['direction_normal'] = new_direction == 'forward'
            self.save_app()

        def set_fit_mode(attr, old_mode, new_mode):
            self.app_context['fit_mode'] = new_mode
            self.save_app()

        def set_band_type(attr, old_type, new_type):
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['type'] = new_type

            self.save_app()

        def update_band_display():
            """Redraw every band's points as one polyline each, then save."""
            bands = self.app_context['bands']
            band_lines.data_source.data = {
                'xs': [[p[0] for p in bands[b]['points']] for b in bands],
                'ys': [[p[1] for p in bands[b]['points']] for b in bands],
            }
            self.save_app()

        self.update_band_display = update_band_display

        def on_clear_band():
            if self.active_band in self.app_context['bands']:
                self.app_context['bands'][self.active_band]['points'] = []
                update_band_display()

        def on_remove_band():
            """Delete the active band, drop it from the dropdown and deselect it."""
            if self.active_band in self.app_context['bands']:
                del self.app_context['bands'][self.active_band]
                new_band_options = [
                    b for b in self.app_context['band_options']
                    if b[0] != self.active_band
                ]
                self.band_dropdown.menu = new_band_options
                self.app_context['band_options'] = new_band_options
                self.active_band = None
                update_band_display()

        # Attach callbacks
        self.main_color_range_slider.on_change('value', update_main_colormap)

        figures['main'].on_event(events.Tap, click_main_image)
        self.band_dropdown.on_change('value', on_change_active_band)
        self.pointer_dropdown.on_change('value', on_change_pointer_mode)
        self.add_band_button.on_click(
            lambda: add_band(self.band_name_input.value))
        self.clear_band_button.on_click(on_clear_band)
        self.remove_band_button.on_click(on_remove_band)
        self.center_float_copy.on_click(on_copy_center_float)
        self.center_float_widget.on_change('value', set_center_float_value)
        self.direction_dropdown.on_change('value', set_fit_direction)
        self.fit_mode_dropdown.on_change('value', set_fit_mode)
        self.band_type_dropdown.on_change('value', set_band_type)

        # Image on the left; dropdowns, band controls and cursor info on the right.
        layout = row(
            column(figures['main']),
            column(
                widgetbox(
                    self.pointer_dropdown,
                    self.band_dropdown,
                    self.fit_mode_dropdown,
                    self.band_type_dropdown,
                    self.direction_dropdown,
                ),
                row(self.band_name_input, self.add_band_button),
                row(self.clear_band_button, self.remove_band_button),
                row(self.center_float_widget, self.center_float_copy),
                widgetbox(
                    self._cursor_info,
                    self.main_color_range_slider,
                )))

        doc.add_root(layout)
        doc.title = 'Band Tool'
        self.load_app()
        self.save_app()