# Example #1
def plotConfusionMatrix(df, width, height):
    """Return a Bokeh heat map of a long-format confusion matrix.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per matrix cell, with columns ``Treatment`` (true class,
        drawn on the y axis), ``Prediction`` (predicted class, drawn on the
        x axis) and ``value`` (the cell count used for colour and tooltip).
    width, height : int
        Figure width and height in pixels.

    Returns
    -------
    bokeh figure
        Heat map with a colour bar on the right and transparent backgrounds.
    """
    from bokeh.palettes import Blues8

    # Add a dedicated mapper so the cell colour scales with ``value``;
    # the palette is reversed so larger counts get darker blues.
    mapper = LinearColorMapper(palette=Blues8[::-1],
                               low=df.value.min(),
                               high=df.value.max())

    TOOLS = "hover,save,reset"

    # Define a figure.  The factor ranges must come from the same columns the
    # rects are positioned by (x = Prediction, y = Treatment); they were
    # previously swapped, which hid predicted labels absent from Treatment
    # and ordered the axes by the wrong column.
    p = figure(
        plot_width=width,
        plot_height=height,
        x_range=list(df.Prediction.drop_duplicates()),
        y_range=list(df.Treatment.drop_duplicates()),
        toolbar_location='above',
        tools=TOOLS,
        tooltips=[('Counts', '@value')],
        x_axis_location="below")

    p.xaxis.axis_label = "Prediction"
    p.yaxis.axis_label = "Truth"

    # Create one unit rectangle per (Prediction, Treatment) cell.
    p.rect(x="Prediction",
           y="Treatment",
           width=1,
           height=1,
           source=ColumnDataSource(df),
           line_color=None,
           fill_color=transform('value', mapper))

    # Add a colour bar acting as the legend for the count scale.
    color_bar = ColorBar(color_mapper=mapper,
                         location=(0, 0),
                         label_standoff=12,
                         border_line_color=None,
                         ticker=BasicTicker(desired_num_ticks=len(Blues8)))

    color_bar.background_fill_alpha = 0.0

    p.add_layout(color_bar, 'right')

    # Transparent backgrounds so the figure blends into its container.
    p.background_fill_alpha = 0.0
    p.border_fill_alpha = 0.0

    return p
# Example #2
    def bkplot(self,
               x,
               y,
               color='None',
               radii='None',
               ps=20,
               minps=0,
               alpha=0.8,
               pw=600,
               ph=400,
               palette='Inferno256',
               style='smapstyle',
               Hover=True,
               title='',
               table=False,
               table_width=600,
               table_height=150,
               add_colorbar=True,
               Periodic_color=False,
               return_datasrc=False,
               frac_load=1.0,
               marker=['circle'],
               seed=0,
               **kwargs):
        """Build an interactive Bokeh scatter plot of two columns of ``self.pd``.

        A random subset (``frac_load``) of the rows of ``self.pd`` is drawn as
        ``x`` vs ``y``.  Point size and colour can be driven by other columns.
        A small overview figure shares the data source and shows the main
        plot's current viewport as a rectangle.  An optional property table
        is linked to the same data source.

        Parameters
        ----------
        x, y : str
            Column names used for the point coordinates.
        color : str
            Column used to colour points, or ``'None'`` for a single colour.
            Object-dtype columns are treated as categories (one renderer and
            legend entry per category); other columns are binned onto the
            palette and get an optional colour bar.
        radii : str
            Column used to size points, or ``'None'`` for constant size
            ``ps``.  Object-dtype columns are treated as categories.
        ps : number
            Base point size (constant size, and scale for variable sizes).
        minps : number
            Additive size offset applied when ``radii`` is a column.
        alpha : float
            Fill/line alpha of the points.
        pw, ph : int
            Width and height of the main plot.
        palette : str
            Name of one of the palettes imported below (e.g. ``'Inferno256'``)
            or ``'cosmo'``.  Unknown names raise ``KeyError``.
        style : str
            ``'smapstyle'`` strips grids/ticks/axis lines from the main plot
            as well as the overview; any other value styles only the overview.
        Hover : bool
            Attach a hover tool showing ``id`` and the non-coordinate columns.
        title : str
            Title of the main plot.
        table : bool
            Also build a ``DataTable`` linked to the plotted data.
        table_width, table_height : int
            Size of the property table.
        add_colorbar : bool
            Add a colour bar when ``color`` is a continuous column.
        Periodic_color : bool
            Extend the palette with a blend from its last back to its first
            colour (for periodic properties).
        return_datasrc : bool
            Also return the ``ColumnDataSource``; only honoured when ``table``
            is true.
        frac_load : float
            Fraction of rows to plot, chosen at random with ``seed``.
        marker : list of str
            Marker names; category ``k`` uses ``marker[k]`` when it is a valid
            marker, else ``marker[0]``.  An empty list or an invalid first
            entry selects the full list of supported markers.
        seed : int
            Seed for the row sub-sampling.
        **kwargs
            Forwarded to ``figure`` for the main plot.

        Returns
        -------
        tuple
            ``(plot, oplot)``; with ``table`` also the table layout, and with
            ``table`` and ``return_datasrc`` also the data source.

        Notes
        -----
        Adds (or overwrites) an integer ``'id'`` column in ``self.pd``.
        """
        from bokeh.layouts import widgetbox, column, Spacer
        from bokeh.models import (BasicTicker, CDSView, ColorBar,
                                  ColumnDataSource, CustomJS, FixedTicker,
                                  HoverTool, IndexFilter, LinearColorMapper,
                                  Rect, WheelZoomTool)
        from bokeh.models.widgets import (DataTable, Div, NumberFormatter,
                                          TableColumn)
        from bokeh.plotting import figure
        import bokeh.models.markers as Bokeh_markers
        from bokeh.palettes import (all_palettes, Spectral6, Inferno256,
                                    Viridis256, Greys256, Magma256, Plasma256)
        from bokeh.palettes import Spectral, Inferno, Viridis, Greys, Magma, Plasma
        import pandas as pd

        # ---- random sub-sample of the rows to plot -------------------------
        fulldata = self.pd
        idx = np.arange(len(fulldata))
        fulldata['id'] = idx  # NOTE: mutates self.pd by design (ids for hover/table)
        nload = int(frac_load * len(fulldata))
        # Shuffle a private copy so the 'id' column can never alias the
        # shuffled array, and use a local generator so the global NumPy RNG
        # is left untouched.  RandomState(seed).shuffle yields the same
        # permutation as np.random.seed(seed) followed by np.random.shuffle.
        shuffled = idx.copy()
        np.random.RandomState(seed).shuffle(shuffled)
        idload = np.sort(shuffled[:nload])
        data = fulldata.iloc[idload].copy()

        # ---- palette -------------------------------------------------------
        if palette == 'cosmo':
            COLORS = cosmo()
        else:
            # Explicit name -> palette table for every palette imported above
            # (previously ``locals()[palette]``, which could also resolve
            # unrelated locals such as the ``x``/``color`` arguments).
            named_palettes = {
                'all_palettes': all_palettes,
                'Spectral6': Spectral6,
                'Inferno256': Inferno256,
                'Viridis256': Viridis256,
                'Greys256': Greys256,
                'Magma256': Magma256,
                'Plasma256': Plasma256,
                'Spectral': Spectral,
                'Inferno': Inferno,
                'Viridis': Viridis,
                'Greys': Greys,
                'Magma': Magma,
                'Plasma': Plasma,
            }
            COLORS = named_palettes[palette]

        # ---- markers, tools and the main figure -----------------------------
        marklist = [
            'circle', 'diamond', 'triangle', 'square', 'asterisk', 'cross',
            'inverted_triangle'
        ]
        if not marker or marker[0] not in marklist:
            marker = marklist
        # Columns never listed in the hover tooltips or the property table.
        hidden_props = {
            "CV1", "CV2", "Cv1", "Cv2", "cv1", "cv2", "colors", "radii", "id"
        }
        TOOLS = "pan,reset,tap,save,box_zoom,lasso_select"
        wheel_zoom = WheelZoomTool(dimensions='both')
        # wheel_zoom is the active scroll tool, so it has to be part of the
        # toolbar in both modes (it was missing when Hover was False).
        tools = [TOOLS, wheel_zoom]
        if Hover:
            tooltips = [("id", '@id')]
            tooltips += [(prop, '@' + prop) for prop in data.columns
                         if prop not in hidden_props]
            hover = HoverTool(names=["mycircle"], tooltips=tooltips)
            tools = [TOOLS, hover, wheel_zoom]
        plot = figure(title=title,
                      plot_width=pw,
                      active_scroll=wheel_zoom,
                      plot_height=ph,
                      tools=tools,
                      **kwargs)

        # ---- selection glyphs for the single-renderer modes -----------------
        mdict = {
            'circle': 'Circle',
            'diamond': 'Diamond',
            'triangle': 'Triangle',
            'square': 'Square',
            'asterisk': 'Asterisk',
            'cross': 'Cross',
            'inverted_triangle': 'InvertedTriangle'
        }
        main_glyph = getattr(Bokeh_markers, mdict[marker[0]])
        selected_circle = main_glyph(fill_alpha=0.7,
                                     fill_color="blue",
                                     size=ps * 1.5,
                                     line_color="blue")
        nonselected_circle = main_glyph(fill_alpha=alpha * 0.5,
                                        fill_color='colors',
                                        line_color='colors',
                                        line_alpha=alpha * 0.5)

        # ---- variable point size --------------------------------------------
        if radii == 'None':
            data['radii'] = [ps] * len(data)
        else:
            if data[radii].dtype == 'object':  # Categorical variable for radii
                # Category k gets the raw size (2k)**2.  ``indices`` yields
                # positional row numbers, matching how ``r`` is indexed;
                # ``groups`` yields index labels, which differ from positions
                # after sub-sampling or with a non-default index.
                r = np.zeros(len(data))
                radius_groups = data.groupby(radii).indices
                for k, positions in enumerate(radius_groups.values()):
                    r[positions] = (2 * k)**2
            else:
                r = list(data[radii])
            rn = self.normalize(r)
            data['radii'] = [minps + ps * np.sqrt(val) for val in rn]

        # ---- variable point colour ------------------------------------------
        if color == 'None':
            data['colors'] = ["#31AADE"] * len(data)
            datasrc = ColumnDataSource(data)
            getattr(plot, marker[0])(x,
                                     y,
                                     source=datasrc,
                                     size='radii',
                                     fill_color='colors',
                                     fill_alpha=alpha,
                                     line_color='colors',
                                     line_alpha=alpha,
                                     name="mycircle")
            renderer = plot.select(name="mycircle")
            renderer.selection_glyph = selected_circle
            renderer.nonselection_glyph = nonselected_circle
        elif data[color].dtype == 'object':  # Categorical variable for colors
            # One renderer (marker, colour, legend entry) per category; all
            # share one data source and are restricted to their rows by views.
            group_positions = data.groupby(color).indices
            ngroups = len(group_positions)
            nc = len(COLORS)
            # Spread the categories over the palette.  With more categories
            # than colours, step by one rather than collapsing to one colour.
            istep = max(1, nc // max(ngroups, 1))
            cat_colors = [
                COLORS[min((k + 1) * istep, nc - 1)] for k in range(ngroups)
            ]
            datasrc = ColumnDataSource(data)
            cat_names = []
            for k, (group_item, positions) in enumerate(group_positions.items()):
                cname = 'mycircle' + str(k)
                cat_names.append(cname)
                if k < len(marker) and marker[k] in mdict:
                    mk = marker[k]
                else:
                    mk = marker[0]
                # IndexFilter selects rows by *position* in datasrc, hence the
                # positional ``indices`` rather than the label-based ``groups``.
                view = CDSView(
                    source=datasrc,
                    filters=[IndexFilter(indices=positions.tolist())])
                getattr(plot, mk)(x,
                                  y,
                                  source=datasrc,
                                  size='radii',
                                  fill_color=cat_colors[k],
                                  muted_color=cat_colors[k],
                                  muted_alpha=0.2,
                                  fill_alpha=alpha,
                                  line_alpha=alpha,
                                  line_color=cat_colors[k],
                                  name=cname,
                                  legend=group_item,
                                  view=view)
                cat_glyph = getattr(Bokeh_markers, mdict[mk])
                renderer = plot.select(name=cname)
                renderer.selection_glyph = cat_glyph(fill_alpha=0.7,
                                                     fill_color="blue",
                                                     size=ps * 1.5,
                                                     line_color="blue",
                                                     line_alpha=0.7)
                renderer.nonselection_glyph = cat_glyph(
                    fill_alpha=alpha * 0.5,
                    fill_color=cat_colors[k],
                    line_color=cat_colors[k],
                    line_alpha=alpha * 0.5)
            if Hover:
                # The hover tool was bound to "mycircle", which does not exist
                # in this mode; bind it to the per-category renderers too.
                hover.names = ["mycircle"] + cat_names
            plot.legend.location = "top_left"
            plot.legend.orientation = "vertical"
            plot.legend.click_policy = "hide"
        else:
            if Periodic_color:  # periodic property: blend last colour back to first
                blendcolor = interpolate(COLORS[-1], COLORS[0],
                                         len(COLORS) / 5)
                COLORS = COLORS + blendcolor
            # Bin the values into len(COLORS) equal-width bins, one colour each.
            groups = pd.cut(data[color].values, len(COLORS))
            data['colors'] = [COLORS[code] for code in groups.codes]
            datasrc = ColumnDataSource(data)
            getattr(plot, marker[0])(x,
                                     y,
                                     source=datasrc,
                                     size='radii',
                                     fill_color='colors',
                                     fill_alpha=alpha,
                                     line_color='colors',
                                     line_alpha=alpha,
                                     name="mycircle")
            renderer = plot.select(name="mycircle")
            renderer.selection_glyph = selected_circle
            renderer.nonselection_glyph = nonselected_circle
            color_mapper = LinearColorMapper(COLORS,
                                             low=data[color].min(),
                                             high=data[color].max())
            colorbar = ColorBar(color_mapper=color_mapper,
                                ticker=BasicTicker(),
                                label_standoff=4,
                                border_line_color=None,
                                location=(0, 0),
                                orientation="vertical")
            colorbar.background_fill_alpha = 0
            colorbar.border_line_alpha = 0
            if add_colorbar:
                plot.add_layout(colorbar, 'left')

        # ---- overview plot with a viewport rectangle -------------------------
        oplot = figure(title='',
                       plot_width=200,
                       plot_height=200,
                       toolbar_location=None)
        oplot.circle(x,
                     y,
                     source=datasrc,
                     size=4,
                     fill_alpha=0.6,
                     line_color=None,
                     name="mycircle")
        orenderer = oplot.select(name="mycircle")
        orenderer.selection_glyph = selected_circle
        rectsource = ColumnDataSource({'xs': [], 'ys': [], 'wd': [], 'ht': []})
        # Keep the rectangle's centre/extent in sync with the main plot range.
        jscode = """
                var data = source.data;
                var start = range.start;
                var end = range.end;
                data['%s'] = [start + (end - start) / 2];
                data['%s'] = [end - start];
                source.change.emit();
             """
        plot.x_range.callback = CustomJS(args=dict(source=rectsource,
                                                   range=plot.x_range),
                                         code=jscode % ('xs', 'wd'))
        plot.y_range.callback = CustomJS(args=dict(source=rectsource,
                                                   range=plot.y_range),
                                         code=jscode % ('ys', 'ht'))
        rect = Rect(x='xs',
                    y='ys',
                    width='wd',
                    height='ht',
                    fill_alpha=0.1,
                    line_color='black',
                    fill_color='red')
        oplot.add_glyph(rectsource, rect)

        # ---- plot style -------------------------------------------------------
        plot.toolbar.logo = None
        oplot.toolbar.logo = None
        plist = [plot, oplot] if style == 'smapstyle' else [oplot]
        for p in plist:
            p.xgrid.grid_line_color = None
            p.ygrid.grid_line_color = None
            p.xaxis[0].ticker = FixedTicker(ticks=[])
            p.yaxis[0].ticker = FixedTicker(ticks=[])
            p.outline_line_width = 0
            p.outline_line_alpha = 0
            p.background_fill_alpha = 0
            p.border_fill_alpha = 0
            p.xaxis.axis_line_width = 0
            p.xaxis.axis_line_color = "white"
            p.yaxis.axis_line_width = 0
            p.yaxis.axis_line_color = "white"
            p.yaxis.axis_line_alpha = 0

        # ---- optional property table ------------------------------------------
        if not table:
            return plot, oplot
        tcolumns = [
            TableColumn(field='id',
                        title='id',
                        formatter=NumberFormatter(format='0'))
        ]
        for prop in data.columns:
            if prop in hidden_props:
                continue
            dtype = data[prop].dtype
            if dtype == 'object':
                tcolumns.append(TableColumn(field=prop, title=prop))
            elif dtype == 'float64':
                tcolumns.append(
                    TableColumn(field=prop,
                                title=prop,
                                formatter=NumberFormatter(format='0.00')))
            elif dtype == 'int64':
                tcolumns.append(
                    TableColumn(field=prop,
                                title=prop,
                                formatter=NumberFormatter(format='0')))
        data_table = DataTable(source=datasrc,
                               fit_columns=True,
                               scroll_to_selection=True,
                               columns=tcolumns,
                               name="Property Table",
                               width=table_width,
                               height=table_height)
        div = Div(text="""<h6><b> Property Table </b> </h6> <br>""",
                  width=600,
                  height=10)
        table_layout = column(widgetbox(div), Spacer(height=10),
                              widgetbox(data_table))
        # NOTE: return_datasrc is only honoured together with table=True
        # (kept as-is so existing callers' tuple unpacking keeps working).
        if return_datasrc:
            return plot, oplot, table_layout, datasrc
        return plot, oplot, table_layout