def plotConfusionMatrix(df, width, height):
    """Draw a confusion-matrix heatmap with Bokeh.

    Args:
        df: DataFrame with columns ``Prediction`` (predicted label, drawn on
            the x axis), ``Treatment`` (true label, drawn on the y axis) and
            ``value`` (the count shown in each cell and used for its colour).
            Presumably one row per (truth, prediction) pair.
        width: figure width in pixels.
        height: figure height in pixels.

    Returns:
        The Bokeh figure, with a colour bar on the right and transparent
        background/border fills.
    """
    from bokeh.palettes import Blues8

    # Add a specific mapper to map cell value to colour; the palette is
    # reversed so that larger counts are drawn darker.
    mapper = LinearColorMapper(palette=Blues8[::-1],
                               low=df.value.min(),
                               high=df.value.max())

    TOOLS = "hover,save,reset"

    # Each categorical range must be built from the column that is plotted
    # on that axis (x = Prediction, y = Treatment); otherwise labels present
    # in only one of the columns fall outside the range and are not drawn.
    p = figure(
        plot_width=width,
        plot_height=height,
        x_range=list(df.Prediction.drop_duplicates()),
        y_range=list(df.Treatment.drop_duplicates()),
        toolbar_location='above',
        tools=TOOLS,
        tooltips=[('Counts', '@value')],
        x_axis_location="below")
    p.xaxis.axis_label = "Prediction"
    p.yaxis.axis_label = "Truth"

    # One unit square per (prediction, truth) cell, coloured by its value.
    p.rect(x="Prediction", y="Treatment", width=1, height=1,
           source=ColumnDataSource(df),
           line_color=None,
           fill_color=transform('value', mapper))

    # Colour-bar legend with one tick per palette entry.
    color_bar = ColorBar(color_mapper=mapper,
                         location=(0, 0),
                         label_standoff=12,
                         border_line_color=None,
                         ticker=BasicTicker(desired_num_ticks=len(Blues8)))
    color_bar.background_fill_alpha = 0.0
    p.add_layout(color_bar, 'right')

    p.background_fill_alpha = 0.0
    p.border_fill_alpha = 0.0
    return p
def bkplot(self, x, y, color='None', radii='None', ps=20, minps=0, alpha=0.8,
           pw=600, ph=400, palette='Inferno256', style='smapstyle',
           Hover=True, title='', table=False, table_width=600,
           table_height=150, add_colorbar=True, Periodic_color=False,
           return_datasrc=False, frac_load=1.0, marker=['circle'], seed=0,
           **kwargs):
    """Build an interactive Bokeh scatter plot of the rows of ``self.pd``.

    Alongside the main plot a 200x200 overview plot is built; a red
    rectangle on it follows the main plot's visible x/y range.

    Args:
        x, y: column names of ``self.pd`` used for the point coordinates.
        color: 'None' for a single fixed colour, otherwise a column name.
            Object-dtype columns get one colour, marker and legend entry per
            category; numeric columns are binned onto ``palette`` and a
            colour bar is built.
        radii: 'None' for a constant point size ``ps``, otherwise a column
            name. Object-dtype columns give one size per category; numeric
            values are passed through ``self.normalize`` and mapped to
            ``minps + ps * sqrt(value)``.
        ps, minps: point-size scale and offset.
        alpha: fill/line alpha of the points.
        pw, ph: main plot width/height in pixels.
        palette: name of one of the bokeh palettes imported below
            (e.g. 'Inferno256'), or 'cosmo' to use ``cosmo()``.
        style: 'smapstyle' strips grid/ticks/axis lines from both plots;
            any other value only strips them from the overview plot.
        Hover: add a hover tool listing 'id' and the non-coordinate columns.
        title: main plot title.
        table: also build a DataTable of the plotted properties.
        table_width, table_height: DataTable size in pixels.
        add_colorbar: attach the colour bar (numeric ``color`` only).
        Periodic_color: append a blend from the last palette colour back to
            the first (via ``interpolate``) for periodic properties.
        return_datasrc: also return the ColumnDataSource.
        frac_load: fraction of rows to plot, drawn at random.
        marker: marker names. The first is used for non-categorical plots,
            the i-th for the i-th category (falling back to the first). If
            the first name is not a supported marker the full built-in list
            is used instead.
        seed: seed for the random row subsample.
        **kwargs: forwarded to ``bokeh.plotting.figure`` for the main plot.

    Returns:
        ``(plot, oplot)``; when ``table`` is true the table layout is
        appended, and when ``return_datasrc`` is true the ColumnDataSource
        is appended last.

    Side effects:
        Adds (or overwrites) an integer 'id' column on ``self.pd``.
    """
    from bokeh.layouts import widgetbox, column, Spacer
    from bokeh.models import HoverTool, FixedTicker, WheelZoomTool
    from bokeh.models import CustomJS, Rect, ColorBar, LinearColorMapper, BasicTicker
    from bokeh.plotting import figure
    import bokeh.models.markers as Bokeh_markers
    from bokeh.models import ColumnDataSource, CDSView, IndexFilter
    # Palettes are looked up by name through locals() below, so keep them.
    from bokeh.palettes import all_palettes, Spectral6, Inferno256, Viridis256, Greys256, Magma256, Plasma256
    from bokeh.palettes import Spectral, Inferno, Viridis, Greys, Magma, Plasma
    from bokeh.models.widgets import DataTable, TableColumn, NumberFormatter, Div
    import pandas as pd

    # Tag every row with its position in self.pd so tooltips and the table
    # can refer back to the full data set.
    fulldata = self.pd
    nrows = len(fulldata)
    fulldata['id'] = np.arange(nrows)

    # Reproducible random subsample, kept in original row order. A private
    # RandomState gives the same permutation as seeding the global generator
    # and shuffling arange(nrows), without clobbering global NumPy RNG state.
    nload = int(frac_load * nrows)
    perm = np.random.RandomState(seed).permutation(nrows)
    idload = np.sort(perm[:nload])
    data = fulldata.iloc[idload].copy()

    if palette == 'cosmo':
        COLORS = cosmo()
    else:
        COLORS = locals()[palette]

    marklist = [
        'circle', 'diamond', 'triangle', 'square', 'asterisk', 'cross',
        'inverted_triangle'
    ]
    if marker[0] not in marklist:
        marker = marklist

    # Columns never shown in tooltips or in the property table.
    hidden_cols = {"CV1", "CV2", "Cv1", "Cv2", "cv1", "cv2", "colors", "radii", "id"}

    TOOLS = "pan,reset,tap,save,box_zoom,lasso_select"
    wheel_zoom = WheelZoomTool(dimensions='both')
    hover = None
    if Hover:
        tooltips = [("id", '@id')]
        tooltips += [(prop, '@' + prop) for prop in data.columns
                     if prop not in hidden_cols]
        hover = HoverTool(names=["mycircle"], tooltips=tooltips)
        tools = [TOOLS, hover, wheel_zoom]
    else:
        # wheel_zoom is the active scroll tool, so it must be on the toolbar
        # in this case too.
        tools = [TOOLS, wheel_zoom]
    plot = figure(title=title, plot_width=pw, active_scroll=wheel_zoom,
                  plot_height=ph, tools=tools, **kwargs)

    # Selection / non-selection glyphs and plot styles.
    mdict = {
        'circle': 'Circle',
        'diamond': 'Diamond',
        'triangle': 'Triangle',
        'square': 'Square',
        'asterisk': 'Asterisk',
        'cross': 'Cross',
        'inverted_triangle': 'InvertedTriangle'
    }
    selected_circle = getattr(Bokeh_markers, mdict[marker[0]])(
        fill_alpha=0.7, fill_color="blue", size=ps * 1.5, line_color="blue")
    nonselected_circle = getattr(Bokeh_markers, mdict[marker[0]])(
        fill_alpha=alpha * 0.5, fill_color='colors', line_color='colors',
        line_alpha=alpha * 0.5)

    # Point sizes.
    if radii == 'None':
        data['radii'] = [ps] * len(data)
    else:
        if data[radii].dtype == 'object':
            # Categorical radii: the k-th category gets (2k)**2 before
            # normalisation. grouped.indices gives *positional* row indices;
            # grouped.groups gives index labels, which do not address the
            # numpy array once rows have been subsampled.
            grouped = data.groupby(radii)
            positions = grouped.indices
            r = np.zeros(len(data))
            for rank, group_item in enumerate(grouped.groups.keys()):
                r[positions[group_item]] = (2 * rank) ** 2
        else:
            r = list(data[radii])
        rn = self.normalize(r)
        data['radii'] = [minps + ps * np.sqrt(val) for val in rn]

    # Point colours.
    if color == 'None':
        data['colors'] = ["#31AADE"] * len(data)
        datasrc = ColumnDataSource(data)
        getattr(plot, marker[0])(x, y, source=datasrc, size='radii',
                                 fill_color='colors', fill_alpha=alpha,
                                 line_color='colors', line_alpha=alpha,
                                 name="mycircle")
        renderer = plot.select(name="mycircle")
        renderer.selection_glyph = selected_circle
        renderer.nonselection_glyph = nonselected_circle
    elif data[color].dtype == 'object':
        # Categorical colours: one renderer (and legend entry) per category,
        # each drawn through a CDSView on the shared data source.
        grouped = data.groupby(color)
        categories = list(grouped.groups.keys())
        positions = grouped.indices
        nc = len(COLORS)
        istep = nc // max(len(categories), 1)  # guard against empty data
        # Spread the categories evenly over the palette (first entry skipped).
        cat_colors = [COLORS[min((k + 1) * istep, nc - 1)]
                      for k in range(len(categories))]
        datasrc = ColumnDataSource(data)
        renderer_names = []
        for i, group_item in enumerate(categories):
            # IndexFilter needs row positions in the data source, not labels.
            view = CDSView(source=datasrc,
                           filters=[IndexFilter(positions[group_item].tolist())])
            cname = 'mycircle' + str(i)
            mk = marker[i] if i < len(marker) else marker[0]
            getattr(plot, mk)(x, y, source=datasrc, size='radii',
                              fill_color=cat_colors[i],
                              muted_color=cat_colors[i], muted_alpha=0.2,
                              fill_alpha=alpha, line_alpha=alpha,
                              line_color=cat_colors[i], name=cname,
                              legend=group_item, view=view)
            selected_mk = getattr(Bokeh_markers, mdict[mk])(
                fill_alpha=0.7, fill_color="blue", size=ps * 1.5,
                line_color="blue", line_alpha=0.7)
            nonselected_mk = getattr(Bokeh_markers, mdict[mk])(
                fill_alpha=alpha * 0.5, fill_color=cat_colors[i],
                line_color=cat_colors[i], line_alpha=alpha * 0.5)
            renderer = plot.select(name=cname)
            renderer.selection_glyph = selected_mk
            renderer.nonselection_glyph = nonselected_mk
            renderer_names.append(cname)
        plot.legend.location = "top_left"
        plot.legend.orientation = "vertical"
        plot.legend.click_policy = "hide"
        if hover is not None:
            # HoverTool only queries renderers whose name is listed, so the
            # per-category renderers must be registered explicitly.
            hover.names = ["mycircle"] + renderer_names
    else:
        # Numeric colours: bin the values into len(COLORS) equal-width bins.
        if Periodic_color:
            # For a periodic property, blend the last colour back to the first.
            blendcolor = interpolate(COLORS[-1], COLORS[0], len(COLORS) / 5)
            COLORS = COLORS + blendcolor
        groups = pd.cut(data[color].values, len(COLORS))
        data['colors'] = [COLORS[code] for code in groups.codes]
        datasrc = ColumnDataSource(data)
        getattr(plot, marker[0])(x, y, source=datasrc, size='radii',
                                 fill_color='colors', fill_alpha=alpha,
                                 line_color='colors', line_alpha=alpha,
                                 name="mycircle")
        renderer = plot.select(name="mycircle")
        renderer.selection_glyph = selected_circle
        renderer.nonselection_glyph = nonselected_circle
        color_mapper = LinearColorMapper(COLORS, low=data[color].min(),
                                         high=data[color].max())
        colorbar = ColorBar(color_mapper=color_mapper, ticker=BasicTicker(),
                            label_standoff=4, border_line_color=None,
                            location=(0, 0), orientation="vertical")
        colorbar.background_fill_alpha = 0
        colorbar.border_line_alpha = 0
        if add_colorbar:
            plot.add_layout(colorbar, 'left')

    # Overview plot: small copy of the data plus a rectangle that tracks the
    # main plot's visible region via JS range callbacks.
    oplot = figure(title='', plot_width=200, plot_height=200,
                   toolbar_location=None)
    oplot.circle(x, y, source=datasrc, size=4, fill_alpha=0.6,
                 line_color=None, name="mycircle")
    orenderer = oplot.select(name="mycircle")
    orenderer.selection_glyph = selected_circle

    rectsource = ColumnDataSource({'xs': [], 'ys': [], 'wd': [], 'ht': []})
    jscode = """
        var data = source.data;
        var start = range.start;
        var end = range.end;
        data['%s'] = [start + (end - start) / 2];
        data['%s'] = [end - start];
        source.change.emit();
    """
    plot.x_range.callback = CustomJS(args=dict(source=rectsource, range=plot.x_range),
                                     code=jscode % ('xs', 'wd'))
    plot.y_range.callback = CustomJS(args=dict(source=rectsource, range=plot.y_range),
                                     code=jscode % ('ys', 'ht'))
    rect = Rect(x='xs', y='ys', width='wd', height='ht', fill_alpha=0.1,
                line_color='black', fill_color='red')
    oplot.add_glyph(rectsource, rect)

    # Plot style: hide grids, ticks, outlines and axis lines.
    plot.toolbar.logo = None
    oplot.toolbar.logo = None
    plist = [plot, oplot] if style == 'smapstyle' else [oplot]
    for p in plist:
        p.xgrid.grid_line_color = None
        p.ygrid.grid_line_color = None
        p.xaxis[0].ticker = FixedTicker(ticks=[])
        p.yaxis[0].ticker = FixedTicker(ticks=[])
        p.outline_line_width = 0
        p.outline_line_alpha = 0
        p.background_fill_alpha = 0
        p.border_fill_alpha = 0
        p.xaxis.axis_line_width = 0
        p.xaxis.axis_line_color = "white"
        p.yaxis.axis_line_width = 0
        p.yaxis.axis_line_color = "white"
        p.yaxis.axis_line_alpha = 0

    if not table:
        # return_datasrc was previously ignored on this path.
        if return_datasrc:
            return plot, oplot, datasrc
        return plot, oplot

    # Property table: object columns as text, float64 with 2 decimals,
    # int64 as integers; other dtypes are not shown.
    tcolumns = [TableColumn(field='id', title='id',
                            formatter=NumberFormatter(format='0'))]
    for prop in data.columns:
        if prop in hidden_cols:
            continue
        dtype = data[prop].dtype
        if dtype == 'object':
            tcolumns.append(TableColumn(field=prop, title=prop))
        elif dtype == 'float64':
            tcolumns.append(TableColumn(field=prop, title=prop,
                                        formatter=NumberFormatter(format='0.00')))
        elif dtype == 'int64':
            tcolumns.append(TableColumn(field=prop, title=prop,
                                        formatter=NumberFormatter(format='0')))
    data_table = DataTable(source=datasrc, fit_columns=True,
                           scroll_to_selection=True, columns=tcolumns,
                           name="Property Table", width=table_width,
                           height=table_height)
    div = Div(text="""<h6><b> Property Table </b> </h6> <br>""",
              width=600, height=10)
    table_layout = column(widgetbox(div), Spacer(height=10),
                          widgetbox(data_table))
    if return_datasrc:
        return plot, oplot, table_layout, datasrc
    return plot, oplot, table_layout