def reset_filters():
    """Clear the active topic filter, e.g. after the search box was emptied.

    Resets the module-level ``filter_item`` to -1, gives both filter tables a
    fresh ``IndexFilter()`` on their views, and blanks ``filter_label``.
    """
    global filter_item
    logger.info("reset filters")
    filter_item = -1
    for table in (filter_topics_table, filter_custom_table):
        table.view.filters = [IndexFilter()]
    filter_label.text = ""
def ajoute_toggle_extrémités(self) -> None:
    """Add a hidden end-point layer for the segments and a toggle button to show it.

    Builds ``self.g`` (one row per boundary point of each located geometry of
    ``self.tron``), a GeoJSON source ``self.src_extr`` with an all-rows
    ``IndexFilter`` ``self.filter_extr``, the mapping
    ``self.index_extrémités_par_tron`` from a segment's row number to the row
    numbers of its points in ``self.g``, a circle renderer created with
    ``visible=False``, and ``self.toggle_extr``, whose ``active`` state is
    linked (client-side, via ``js_link``) to that renderer's ``visible``.
    """
    size = 4
    fill_color = "DarkSlateGray"
    # boundary -> presumably the end points of each line geometry; dropna drops
    # missing boundaries; explode gives one point per row and droplevel(1)
    # removes the part level it adds (assumes a MultiIndex with the part number
    # at level 1 — TODO confirm for the pinned geopandas version). The second
    # reset_index inserts a 0..n-1 row-number column at position 0.
    self.g = (self.tron.set_index(
        self.tron_idx_name).geometry.boundary.dropna().explode().droplevel(
            1).rename("geometry").reset_index().reset_index())
    idx_g = self.g.columns[0]  # column that holds the row number
    self.src_extr = GeoJSONDataSource(geojson=self.g.to_json())
    # Initially keeps every point row.
    self.filter_extr = IndexFilter(list(range(self.g.shape[0])))
    self.index_extrémités_par_tron = (
        # row number in column idx_g (assumes self.tron has an unnamed default
        # index, so this column has the same name as idx_g — TODO confirm)
        self.tron.reset_index()
        # inner join, so segments that are not located are absent; the shared
        # row-number column gets the _x (tron) / _y (g) suffixes used below
        .merge(
            self.g, on=self.tron_idx_name
        )
        .groupby(f"{idx_g}_x").apply(
            lambda s: list(s[f"{idx_g}_y"])).to_dict())
    view = CDSView(source=self.src_extr, filters=[self.filter_extr])
    self.extr_renderer = self.p.circle(
        x="x",
        y="y",
        size=size,
        fill_color=fill_color,
        line_color=fill_color,
        source=self.src_extr,
        visible=False,
        view=view,
    )
    self.toggle_extr = Toggle(label="Affiche les extrémités",
                              button_type="success",
                              width=100)
    # Toggling the button shows/hides the end-point renderer in the browser.
    self.toggle_extr.js_link("active", self.extr_renderer, "visible")
def hello(request):
    """Render ``html/index.html`` with two small scatter plots embedded.

    The left plot shows all five points of a shared ``ColumnDataSource``; the
    right plot shows the same source through a ``CDSView`` that keeps only
    rows 0, 2 and 4. Both are laid out in a row and embedded via
    ``bokeh.embed.components``, whose ``<script>``/``<div>`` pair is passed to
    the template as ``script`` and ``div`` together with a few static values.
    """
    # Local imports: only this view needs them.
    from bokeh.embed import components
    from bokeh.layouts import row

    context = {}
    context['hello'] = 'guodegang'
    context['condition'] = False
    context['host'] = '127.0.0.1'
    context['port'] = 4200
    context['user'] = '******'
    # NOTE(review): a hard-coded password is sent to the template; it should
    # come from settings/environment rather than source code.
    context['pwd'] = 'abc123.'

    source = ColumnDataSource(data=dict(x=[1, 2, 3, 4, 5], y=[1, 2, 3, 4, 5]))
    view = CDSView(source=source, filters=[IndexFilter([0, 2, 4])])
    tools = ["box_select", "hover", "reset"]

    p = figure(plot_height=300, plot_width=300, tools=tools)
    p.circle(x="x", y="y", size=10, hover_color="red", source=source)

    p_filtered = figure(plot_height=300, plot_width=300, tools=tools)
    p_filtered.circle(x="x", y="y", size=10, hover_color="red",
                      source=source, view=view)

    # components() yields the script/div pair the template embeds; without
    # this call `script` and `div` would be undefined (NameError).
    script, div = components(row(p, p_filtered))
    context['script'] = script
    context['div'] = div
    return render(request, 'html/index.html', context)
def _view_hit(self, L_raw):
    """Return a view of the raw source restricted to rows the labeling function voted on.

    :param L_raw: labeling-function outputs aligned with ``self.sources["raw"]``
        (presumably 1-D, one entry per row — TODO confirm).
    :returns: ``CDSView`` over ``self.sources["raw"]`` keeping every row whose
        output differs from ``module_config.ABSTAIN_DECODED``.
    """
    voted = L_raw != module_config.ABSTAIN_DECODED
    hit_rows = np.where(voted)[0].tolist()
    raw_source = self.sources["raw"]
    return CDSView(source=raw_source, filters=[IndexFilter(hit_rows)])
def _view_missed(self, L_labeled, targets):
    """Return a view of the labeled source restricted to rows the labeling function missed.

    A row counts as missed when its ground-truth ``label`` is one of
    ``targets`` but the labeling function abstained on it
    (``module_config.ABSTAIN_DECODED``).

    :param L_labeled: labeling-function outputs aligned with
        ``self.dfs["labeled"]`` rows.
    :param targets: the labels the function is meant to produce.
    """
    truth = self.dfs["labeled"]["label"]
    in_target = np.isin(truth, targets)
    did_abstain = L_labeled == module_config.ABSTAIN_DECODED
    missed_rows = np.where(np.logical_and(in_target, did_abstain))[0].tolist()
    labeled_source = self.sources["labeled"]
    return CDSView(source=labeled_source, filters=[IndexFilter(missed_rows)])
def _view_incorrect(self, L_labeled):
    """Return a view of the labeled source restricted to rows the labeling function got wrong.

    A row is incorrect when the function did not abstain
    (``module_config.ABSTAIN_DECODED``) and its output differs from the
    ground-truth ``label`` column of ``self.dfs["labeled"]``.
    """
    truth = self.dfs["labeled"]["label"].values
    wrong_value = truth != L_labeled
    did_vote = L_labeled != module_config.ABSTAIN_DECODED
    wrong_rows = np.where(np.logical_and(wrong_value, did_vote))[0].tolist()
    labeled_source = self.sources["labeled"]
    return CDSView(source=labeled_source, filters=[IndexFilter(wrong_rows)])
def test_check_cdsview_filters_with_connected_error(self, glyph) -> None:
    """The connected-glyph check reports nothing until the view gets a filter.

    ``glyph`` is presumably a connected glyph class (e.g. Line, Patch) supplied
    by the test's parametrization — TODO confirm.
    """
    renderer = bmr.GlyphRenderer(data_source=ColumnDataSource())
    renderer.glyph = glyph()

    # Before any filter is set, the check must come back clean.
    assert renderer._check_cdsview_filters_with_connected() == []

    # Installing an IndexFilter must make the check report an issue.
    renderer.view.filters = [IndexFilter()]
    assert renderer._check_cdsview_filters_with_connected() != []
def test_check_cdsview_filters_with_connected_error(self, glyph):
    """A default GlyphRenderer with a (presumably connected) glyph is only
    flagged by the connected-filters check once an IndexFilter is present."""
    renderer = bmr.GlyphRenderer()
    renderer.glyph = glyph()

    # Clean before filtering ...
    assert renderer._check_cdsview_filters_with_connected() == []

    # ... and flagged after.
    renderer.view.filters = [IndexFilter()]
    assert renderer._check_cdsview_filters_with_connected() != []
def plot():
    """Render ``graph.html`` with a stock-price line and sentiment-coloured markers.

    Reads ``test2/sent_stock_data3.csv`` (columns ``date_time``,
    ``stock_price`` and ``avg_sentiment``), draws the price as a line and one
    marker per row whose fill colour encodes the average sentiment: green
    above 0.1, red below -0.1, blue otherwise. The hover tool shows the row's
    sentiment and price.

    Returns the rendered template with the Bokeh ``script1``/``div1`` pair and
    ``cdn_json``, the first BokehJS CDN URL.
    """
    df = pd.read_csv('test2/sent_stock_data3.csv')
    df['date_time'] = pd.to_datetime(df['date_time'])

    # One colour per row. NaN compares False both ways, so it maps to blue.
    colors = [
        'green' if z > 0.1 else 'red' if z < -0.1 else 'blue'
        for z in df['avg_sentiment']
    ]
    source = ColumnDataSource(data=dict(x=df.date_time,
                                        y=df.stock_price,
                                        z=df.avg_sentiment,
                                        color=colors))

    fig = figure(title="Stock price with sentiment graph",
                 x_axis_label='Date',
                 y_axis_label='Stock price',
                 x_axis_type='datetime')
    fig.line(x='x', y='y', source=source, line_color='blue', line_width=5)
    # A single renderer driven by the colour column, instead of one renderer
    # and one CDSView per row (which grows the document linearly with rows).
    fig.circle(x='x', y='y', source=source, fill_color='color', size=50)

    hover = HoverTool()
    # '@z' shows each row's actual sentiment (previously a fixed "neutral").
    hover.tooltips = [('Average sentiment', '@z'), ('Exact price', '@y')]
    fig.add_tools(hover)

    script1, div1 = components(fig)
    # The template only needs the first BokehJS CDN script URL.
    cdn_json = CDN.js_files[0]
    return render_template('graph.html',
                           script1=script1,
                           div1=div1,
                           cdn_json=cdn_json)
def create_view(source, group_index, subgroup_col, widget):
    """Return a ``CDSView`` over ``source`` for one group, narrowed by its widget.

    The view always keeps the rows in ``group_index``. If ``widget`` is a
    ``Row`` (whose second child is a checkbox group) or a ``RangeSlider``, the
    filter built from it by ``_checkbox_filter`` / ``_slider_filter`` is added
    too; any other widget contributes no extra filter.
    """
    base = IndexFilter(group_index)
    if isinstance(widget, Row):
        checkbox_group = widget.children[1]
        extra = [_checkbox_filter(checkbox_group, source, subgroup_col)]
    elif isinstance(widget, RangeSlider):
        extra = [_slider_filter(widget, source, subgroup_col)]
    else:
        extra = []
    return CDSView(source=source, filters=[base, *extra])
def build_venn_figure(self):
    """Create the Venn-style circle figure over ``self.asource``.

    Sets ``self.view`` to a ``CDSView`` keeping only row 0, then draws an
    unfilled circle and a centred count label per visible row (selected items
    turn firebrick). Grid lines, axis ticks and tick labels are hidden.
    """
    tools = "hover,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,"
    tooltips = [
        ("index", "$index"),
        ("Name", "@name"),
        ("Size", "@approximate_count"),
    ]
    self.view = CDSView(source=self.asource, filters=[IndexFilter([0])])

    fig = figure(tools=tools,
                 x_range=[-2, 2],
                 y_range=[-2, 2],
                 plot_width=700,
                 plot_height=600,
                 tooltips=tooltips)

    shared = dict(source=self.asource, view=self.view)
    fig.circle(x='x', y='y', radius='r', line_color='c',
               fill_color=None, fill_alpha=0.2,
               selection_line_color="firebrick",
               selection_fill_color=None,
               nonselection_fill_color=None,
               **shared)
    fig.text(x='x', y='y', text='approximate_count_txt',
             text_font_size='12px', text_align='center',
             text_baseline='middle', text_color='c',
             selection_text_color="firebrick",
             **shared)

    # Bare canvas: no grid, no ticks, no tick labels.
    for grid in (fig.xgrid, fig.ygrid):
        grid.visible = False
    for axis in (fig.xaxis, fig.yaxis):
        axis.major_tick_line_color = None
        axis.minor_tick_line_color = None
        axis.major_label_text_font_size = '0pt'
    return fig
def plot(self, label, **kwargs):
    """Plot the margins about a single label.

    For every entry of ``self.sources``, draws three marker subsets, each a
    ``CDSView`` whose two ``IndexFilter``s are combined by Bokeh:

    * square: rows where both ``label_col_a`` and ``label_col_b`` equal ``label``;
    * x: rows where only ``label_col_b`` equals ``label``;
    * cross: rows where only ``label_col_a`` equals ``label``.

    Glyph options are ``self.glyph_kwargs[key]`` overridden by ``kwargs``, with
    ``legend_label`` forced to ``str(label)``.
    """
    for key, source in self.sources.items():
        glyph_opts = self.glyph_kwargs[key].copy()
        glyph_opts.update(kwargs)
        glyph_opts["legend_label"] = f"{label}"

        frame = self.dfs[key]
        col_a = frame[self.label_col_a]
        col_b = frame[self.label_col_b]
        a_hit = np.where(col_a == label)[0].tolist()
        a_miss = np.where(col_a != label)[0].tolist()
        b_hit = np.where(col_b == label)[0].tolist()
        b_miss = np.where(col_b != label)[0].tolist()

        # (rows kept from column a, rows kept from column b):
        # agreement, increment, decrement.
        row_pairs = [(a_hit, b_hit), (a_miss, b_hit), (a_hit, b_miss)]
        views = [
            CDSView(source=source,
                    filters=[IndexFilter(a_rows), IndexFilter(b_rows)])
            for a_rows, b_rows in row_pairs
        ]
        markers = [self.figure.square, self.figure.x, self.figure.cross]

        for view, marker in zip(views, markers):
            marker("x", "y", name=key, source=source, view=view, **glyph_opts)
def get_threshold_summary_plot(ds):
    """Plot scripts captured per distance threshold, one line pair per recall threshold.

    Reads ``recall_summary_plot_data.csv`` from the ``DYESCORE_RESULTS_DIR``
    configured on ``ds`` (through ``ds.s3`` when it is set, otherwise from the
    local file system). For each recall threshold, draws ``n_over_threshold``
    against ``distance_threshold`` on the left axis and ``percent`` on a
    secondary right axis. Clicking a legend entry hides that threshold's lines.

    Returns the Bokeh figure.
    """
    resultsdir = ds.config('DYESCORE_RESULTS_DIR')
    inpath = os.path.join(resultsdir, 'recall_summary_plot_data.csv')
    ds.file_in_validation(inpath)
    if ds.s3:
        with ds.s3.open(inpath, 'r') as f:
            results_df = pd_read_csv(f)
    else:
        results_df = pd_read_csv(inpath)

    # One row per recall threshold; each cell holds that threshold's values as
    # a list (the xs/ys of the multi_line glyphs).
    grouped_results_df = results_df.groupby('recall_threshold').agg(
        lambda x: list(x))
    # Thresholds come from the grouped index so that threshold i is exactly
    # row i of the source selected by IndexFilter([i]) below. groupby drops NaN
    # keys, while unique() keeps them and would shift the pairing.
    recall_thresholds = list(grouped_results_df.index)
    palette = inferno(len(recall_thresholds) + 1)  # The yellow is often a little light
    source = ColumnDataSource(grouped_results_df)
    p = figure(
        title=
        f'Scripts captured by distance threshold for {len(recall_thresholds)} recall thresholds (colored)',
        width=800,
        toolbar_location=None,
        tools='',
        y_range=Range1d(results_df.n_over_threshold.min(),
                        results_df.n_over_threshold.max()),
    )
    p.xaxis.axis_label = 'distance threshold'
    p.yaxis.axis_label = 'minimum n_scripts'
    p.yaxis.formatter = NumeralTickFormatter(format="0a")
    p.extra_y_ranges = {
        'percent': Range1d(results_df.percent.min(), results_df.percent.max())
    }
    p.add_layout(
        LinearAxis(y_range_name='percent',
                   axis_label='minimum n_scripts (percent of total)',
                   formatter=NumeralTickFormatter(format='0%')), 'right')
    for i, recall_threshold in enumerate(recall_thresholds):
        view = CDSView(source=source, filters=[IndexFilter([i])])
        # NOTE(review): `legend=` is deprecated in favour of `legend_label=` in
        # newer Bokeh releases; switch once the pinned version is confirmed.
        opts = dict(source=source,
                    view=view,
                    legend=str(recall_threshold),
                    color=palette[i],
                    line_width=5,
                    line_alpha=0.6)
        p.multi_line(xs='distance_threshold', ys='n_over_threshold', **opts)
        p.multi_line(xs='distance_threshold',
                     ys='percent',
                     y_range_name='percent',
                     **opts)
    p.legend.click_policy = 'hide'
    return p
def selection_callback(attr, old, new):
    """Selection callback: rebuild the table so it shows only the selected points.

    Despite its name, the global ``table_source`` holds a ``CDSView`` over
    ``source``. With no selected indices it is unfiltered; otherwise it keeps
    just the selected rows. The third child of ``layout`` is then replaced by
    a fresh ``create_table()``.
    """
    global table_source
    selected = np.array(new['1d']['indices'])
    if len(selected) > 0:
        table_source = CDSView(source=source, filters=[IndexFilter(selected)])
    else:
        table_source = CDSView(source=source)
    layout.children[2] = create_table()
def init_layout(self) -> None:
    """Build this view's glyphs and widgets on top of the parent layout.

    Creates the GeoJSON source for the segments in ``self.tron`` and an
    all-rows ``IndexFilter``, then adds the lines, the end-point toggle, the
    selection callback, the segment table and the number input. It restricts
    the hover tool to the ``"tronçons"`` renderer, adds the legend and stores
    the first widget row in ``self.première_ligne``.
    """
    import json

    super().init_layout()
    # Features with a missing geometry are serialised as `"geometry": null`;
    # give those an empty Point instead. Only the geometry member is
    # rewritten: a blanket text replace of "null" would also turn null
    # property values (e.g. NaN attributes) into geometry objects.
    collection = json.loads(self.tron.to_json())
    for feature in collection.get("features", []):
        if feature.get("geometry") is None:
            feature["geometry"] = {"type": "Point", "coordinates": []}
    geojson = json.dumps(collection)
    self.source_lines = GeoJSONDataSource(geojson=geojson)
    self.filter = IndexFilter(list(range(self.tron.shape[0])))
    self.ajoute_lignes()
    self.ajoute_toggle_extrémités()
    self.source_lines.selected.js_on_change("indices", self.callback_selected)
    self.ajoute_table_tron()
    self.ajoute_input_num()
    self.hover_tool = self.p.select(type=HoverTool)
    self.hover_tool.names = ["tronçons"]
    self.ajoute_légende()
    self.première_ligne = row(self.input_num, self.toggle_extr)
def update(attr, old, new):
    """Widget callback: reload sensor data and refresh the plot source and table.

    Reads the selected sensors, the start/end values from ``text_input`` and
    the sensor name from ``select``, rebuilds the dataset with
    ``make_dataset`` and pushes it into ``source``. It then resets the data
    table's columns and view so the table shows at most the first and last
    10 rows.
    """
    capteur_to_plot = [
        capteur_selection.labels[i] for i in capteur_selection.active
    ]
    # NOTE(review): start/end are read from fixed lines 1 and 4 of the text
    # input, so this assumes that widget's multi-line layout — confirm.
    L_text = text_input.value.split('\n')
    text_input_start = L_text[1]
    text_input_end = L_text[4]
    nom_capteur = select.value
    new_src = make_dataset(capteur_to_plot, text_input_start, text_input_end,
                           nom_capteur, L_influx)

    source.data = {col: new_src[col] for col in new_src.columns}

    table_columns = [
        TableColumn(field='Date',
                    title='Date',
                    formatter=DateFormatter(format="%m/%d/%Y %H:%M:%S"))
    ]
    # Every column after the first (assumed to be the date) is shown as-is.
    table_columns += [
        TableColumn(field=col, title=col) for col in new_src.columns[1:]
    ]
    datatable.columns = table_columns

    # First and last 10 row positions, de-duplicated and clipped to the data,
    # so short tables neither repeat rows nor request negative indices.
    longueur = new_src.shape[0]
    liste = sorted(
        set(range(min(10, longueur)))
        | set(range(max(longueur - 10, 0), longueur)))
    # The view must reference the source the table actually renders.
    datatable.view = CDSView(source=datatable.source,
                             filters=[IndexFilter(indices=liste)])
def create_sample_scatter(x_data, y_data, source, title='', x_axis_title='',
                          y_axis_title=''):
    """Scatter ``x_data`` vs ``y_data`` with one legend group per ``roman_label`` value.

    For each value in the module-level ``roman_label``, a circle scatter over
    ``source`` is restricted (``IndexFilter``) to rows whose ``style_label``
    equals it. Its legend text is the ``legend_label`` of the last such row,
    or '' if there is none. Clicking a legend entry hides its group, and the
    module-level ``vline``/``hline`` renderers are appended to mark the axes.
    """
    result_plot = figure(title=title, tools=tools_list, tooltips=custom_tooltip)
    result_plot.xaxis.axis_label = x_axis_title
    result_plot.yaxis.axis_label = y_axis_title

    for label in roman_label:
        styles = source.data['style_label']
        rows = [i for i in range(len(styles)) if styles[i] == label]
        legend_label = source.data['legend_label'][rows[-1]] if rows else ''
        view = CDSView(source=source, filters=[IndexFilter(rows)])
        result_plot.scatter(
            x_data,
            y_data,
            source=source,
            view=view,
            marker='circle',
            size=12,
            fill_alpha=0.4,
            muted_alpha=0.1,
            color=factor_cmap('style_label', 'Category20_16', roman_label),
            legend_label=legend_label,
        )

    result_plot.legend.click_policy = "hide"
    result_plot.renderers.extend([vline, hline])
    return result_plot
def create_sample_scatter(x_data, y_data, source, label, title='',
                          x_axis_title='', y_axis_title=''):
    """Scatter ``x_data`` vs ``y_data`` with one legend group per category.

    ``label`` is a mapping with ``'real_label_list'`` (the name of the
    category column in ``source``) and ``'standard_label_list'`` (the
    categories to plot). Each category gets its own scatter, restricted to
    its rows, with marker and colour taken from the shared factor maps.
    Clicking a legend entry mutes its group, and the module-level
    ``vline``/``hline`` renderers are appended to mark the axes.
    """
    result_plot = figure(title=title, tools=tools_list, tooltips=custom_tooltip)
    result_plot.xaxis.axis_label = x_axis_title
    result_plot.yaxis.axis_label = y_axis_title

    for cat_filter in label['standard_label_list']:
        values = source.data[label['real_label_list']]
        rows = [i for i in range(len(values)) if values[i] == cat_filter]
        view = CDSView(source=source, filters=[IndexFilter(rows)])
        result_plot.scatter(
            x_data,
            y_data,
            source=source,
            view=view,
            size=8,
            fill_alpha=0.4,
            muted_alpha=0.1,
            marker=factor_mark(label['real_label_list'], markers,
                               label['standard_label_list']),
            color=factor_cmap(label['real_label_list'], 'Category10_8',
                              label['standard_label_list']),
            legend=cat_filter,
        )

    result_plot.legend.click_policy = "mute"
    result_plot.renderers.extend([vline, hline])
    return result_plot
def make_table(source, src):
    """Build a DataTable over ``source`` showing at most the first and last 10 rows.

    :param source: the ColumnDataSource the table renders.
    :param src: DataFrame the data came from; its first column is assumed to
        be the date (shown with ``DateFormatter``) — TODO confirm. Every other
        column gets a plain ``TableColumn``.
    :returns: the configured ``DataTable``.
    """
    table_columns = [
        TableColumn(field='Date',
                    title='Date',
                    formatter=DateFormatter(format="%m/%d/%Y %H:%M:%S"))
    ]
    table_columns += [TableColumn(field=col, title=col) for col in src.columns[1:]]

    # First and last 10 row positions, de-duplicated and clipped to the data,
    # so short tables neither repeat rows nor request negative indices.
    longueur = src.shape[0]
    liste = sorted(
        set(range(min(10, longueur)))
        | set(range(max(longueur - 10, 0), longueur)))
    view1 = CDSView(source=source, filters=[IndexFilter(indices=liste)])
    return DataTable(source=source,
                     columns=table_columns,
                     width=1200,
                     height=1000,
                     view=view1)
plot = Plot(width=1000, x_range=xdr, y_range=ydr, toolbar_location=None) plot.title.text = "Usain Bolt vs. 116 years of Olympic sprinters" xticker = SingleIntervalTicker(interval=5, num_minor_ticks=0) xaxis = LinearAxis(ticker=xticker, axis_label="Meters behind 2012 Bolt") plot.add_layout(xaxis, "below") xgrid = Grid(dimension=0, ticker=xaxis.ticker) plot.add_layout(xgrid) yticker = SingleIntervalTicker(interval=12, num_minor_ticks=0) yaxis = LinearAxis(ticker=yticker, major_tick_in=-5, major_tick_out=10) plot.add_layout(yaxis, "right") filters = [IndexFilter(list(sprint.query('Medal == "gold" and Year in [1988, 1968, 1936, 1896]').index))] medal = Circle(x="MetersBack", y="Year", size=10, fill_color="MedalFill", line_color="MedalLine", fill_alpha=0.5) medal_renderer = plot.add_glyph(source, medal) #sprint[sprint.Medal=="gold" * sprint.Year in [1988, 1968, 1936, 1896]] plot.add_glyph(source, Text(x="MetersBack", y="Year", x_offset=10, text="Name"), view=CDSView(source=source, filters=filters)) plot.add_glyph(source, Text(x=7.5, y=1942, text=["No Olympics in 1940 or 1944"], text_font_style="italic", text_color="silver")) tooltips = """ <div> <span style="font-size: 15px;">@Name</span> <span style="font-size: 10px; color: #666;">(@Abbrev)</span> </div>
from bokeh.layouts import gridplot
from bokeh.models import CDSView, ColumnDataSource, IndexFilter
from bokeh.plotting import figure, show

# The same five points are drawn twice: unfiltered on the left, and on the
# right through a view that keeps only rows 0, 2 and 4.
source = ColumnDataSource(data=dict(x=[1, 2, 3, 4, 5], y=[1, 2, 3, 4, 5]))
# Bokeh 3 takes a single `filter`; the list-valued `filters` is deprecated.
view = CDSView(filter=IndexFilter([0, 2, 4]))

tools = ["box_select", "hover", "reset"]

p = figure(height=300, width=300, tools=tools)
p.circle(x="x", y="y", size=10, hover_color="red", source=source)

p_filtered = figure(height=300, width=300, tools=tools)
p_filtered.circle(x="x", y="y", size=10, hover_color="red", source=source,
                  view=view)

show(gridplot([[p, p_filtered]]))
def slider_handler(attr, old, new):
    """Slider callback: show only the image row whose index is the slider value."""
    only_selected = IndexFilter([new])
    img_plot.view = CDSView(source=source, filters=[only_selected])
# yrs=yrs.apply(str) # yrs = yrs.tolist() x = oo_Year_Gender.Year #x data: years #Figure Stuff p = figure( x_range=[1895, 2009], y_range=[-10, 1300], x_axis_label='Year', y_axis_label='Medal Count', #"pan,box_zoom,reset,save,hover" title="Total Medal Count by Gender per Year") #Filter Data for World War cancellation view = CDSView(filters=[ IndexFilter([ 0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28 ]) ]) patch1 = p.patch(x=[1912, 1912, 1920, 1920], y=[-10, 1300, 1300, -10], color='black', alpha=0.04, line_width=0) patch2 = p.patch(x=[1936, 1936, 1948, 1948], y=[-10, 1300, 1300, -10], color='black', alpha=0.04, line_width=0) p.add_tools( HoverTool(renderers=[patch1], tooltips=[
sum_plot.circle("xsum-circle", "ysum-circle", radius="r", line_color="color", line_width=2, line_dash="dashing", fill_color=None, source=items_source) sum_plot.circle("xsum-dot", "ysum-dot", size=5, color="color", source=items_source) segment_view = CDSView(source=items_source, filters=[IndexFilter([3])]) sum_plot.segment(x0="xsum-dot", y0="ysum-dot", x1=2.5, y1="ysum-dot", color="orange", source=items_source, view=segment_view) sum_plot.xgrid.bounds = sum_plot.xaxis.bounds = (-2.5, 2.5) sum_plot.axis.ticker.desired_num_ticks = 8 @repeat(range(N)) def update(ind): ykeys = (k for k in lines_source.data.keys() if k.startswith("y"))
from bokeh.layouts import gridplot
from bokeh.models import CDSView, ColumnDataSource, IndexFilter
from bokeh.plotting import figure, show

# Five points on the diagonal, shared by both plots.
source = ColumnDataSource(data=dict(x=[1, 2, 3, 4, 5], y=[1, 2, 3, 4, 5]))

# A view exposing only the rows at positions 0, 2 and 4.
view = CDSView(filter=IndexFilter([0, 2, 4]))

tools = ["box_select", "hover", "reset"]


def _scatter(**glyph_extra):
    """Return a 300x300 figure drawing `source` as circles that turn red on hover."""
    fig = figure(height=300, width=300, tools=tools)
    fig.circle(x="x", y="y", size=10, hover_color="red", source=source,
               **glyph_extra)
    return fig


# Left: every row. Right: only the rows kept by `view`.
p = _scatter()
p_filtered = _scatter(view=view)

show(gridplot([[p, p_filtered]]))
def show_overlaps(self):
    """Lay out the pivot and the matching rows as overlapping circles and show them.

    Collects the rows of ``self.asource`` that are either currently selected
    or whose ``name`` contains one of the comma-separated substrings from
    ``self.substring_inp``. The row named ``self.pivot_inp.value`` is the
    pivot and is handled separately. The collected rows are ranked by
    ``approximate_count`` or by overlap with the pivot, positioned with
    ``place_circles``, and their ``x``/``y``/``r`` values are patched into
    ``self.asource``. The heat map is updated, ``self.view`` is restricted to
    the pivot plus the ranked rows, tab 1 is activated and the selection is
    cleared.

    Any exception (including "Nothing selected") is printed, not raised. The
    show button is disabled while this runs and re-enabled in ``finally``.
    """
    try:
        self.show_btn.disabled = True
        # NOTE(review): maxaud is read here but not used anywhere below.
        maxaud = self.maxaud_inp.value
        thresh = self.thresh_inp.value
        rankby = self.rankby_inp.value
        method = self.method_inp.value
        substrings = self.substring_inp.value.split(',')
        pivot = self.pivot_inp.value
        # Row of the pivot; stays 0 when no name equals the pivot.
        pividx = 0
        idxs = []
        for idx, name in enumerate(self.asource.data["name"]):
            if name == pivot:
                pividx = idx
            # An empty substring (e.g. from an empty input, since
            # ''.split(',') == ['']) matches every name.
            elif any([s in name for s in substrings] + ['' in substrings]):
                idxs.append(idx)
            elif idx in self.asource.selected.indices:
                idxs.append(idx)
        if len(idxs) < 1:
            raise RuntimeError('Nothing selected')
        # Only possible when no name matched the pivot and row 0 (the default
        # pivot index) was itself collected above.
        if pividx in idxs:
            idxs.remove(pividx)
        # Column-wise copy of the collected rows.
        data = {
            k: [vs[idx] for idx in idxs]
            for k, vs in self.asource.data.items()
        }
        overlaps = self.data["overlaps"].to_dict()
        df = pd.DataFrame(data)
        pivid = self.asource.data["id"][pividx]
        # Overlap of each collected row with the pivot, looked up by id.
        df["overlap"] = [overlaps[pivid][aid] for aid in df["id"]]
        # Any rankby other than 'size'/'overlap' raises KeyError (caught below).
        sortby = {
            'size': 'approximate_count',
            "overlap": "overlap"
        }[rankby]
        # Carry each row's original source index through the sort.
        df["idx"] = idxs
        df = df.sort_values(sortby, ascending=False)
        data = df.to_dict(orient='list')
        # Prepend the pivot row to every source column.
        for k, vs in self.asource.data.items():
            data[k] = [vs[pividx]] + data[k]
        data["idx"] = [pividx] + data["idx"]
        # The pivot's overlap with itself is its own size.
        data["overlap"] = [data["approximate_count"][0]] + data["overlap"]
        norm = np.max(data["approximate_count"])
        # Pairwise overlaps scaled by the largest size.
        # NOTE(review): this local name shadows the `os` module in this method.
        os = [[overlaps[audid1][audid2] / norm for audid2 in data['id']]
              for audid1 in data['id']]
        # r = sqrt(size / norm), so circle area is proportional to size.
        rs = np.sqrt(np.array(data["approximate_count"]) / norm)
        xs, ys = place_circles(rs, os, thresh, method=method)
        # (row index, new value) pairs for ColumnDataSource.patch.
        p = {
            'x': [(idx, x) for idx, x in zip(data["idx"], xs)],
            'y': [(idx, y) for idx, y in zip(data["idx"], ys)],
            'r': [(idx, r) for idx, r in zip(data["idx"], rs)]
        }
        self.asource.patch(p)
        self.update_heatmap(data["idx"])
        # Show only the pivot and the ranked rows.
        self.view.filters = [IndexFilter(data["idx"])]
        self.tabs.active = 1
        self.asource.selected.update(indices=[])
    except Exception as e:
        # NOTE(review): errors are only printed; consider logging them.
        print(e)
    finally:
        self.show_btn.disabled = False
input_test_img = img.imread("Bokeh//test_app//static//frame(30)_1.bmp") input_test_img = np.ascontiguousarray(np.flipud( np.stack( (input_test_img[..., 0], input_test_img[..., 1], input_test_img[..., 2], np.ones_like(input_test_img[..., 0]) * 255), axis=-1)), dtype=np.uint8).view( dtype=np.uint32).reshape( (input_test_img.shape[0], input_test_img.shape[1])) # create data sources source = ColumnDataSource(data=dict( image=[input_test_img], start_x=[0], start_y=[0], width=[10], height=[10])) view = CDSView(source=source, filters=[IndexFilter([0])]) # view_callback_js = CustomJS( # args=dict(), # code=""" # var filter = this[0] # console.log('filter_ind: value=' + filter.indices[0]) # """ # ) # view.js_on_change('filters', view_callback_js) # create plot img_plot = img_fig.image_rgba(source=source, view=view, image='image', x='start_x', y='start_y',
topic_clustering_after = None white_list = [] custom_nps = [] ref_data_name = "" target_data_name = "" apply_filter = False dont_regenerate = False filter_item = -1 all_topics = [] filter_rows = [] table_columns = [TableColumn(field="topics", title="Topics")] # create ui components filter_topics_table_source = ColumnDataSource(data={}) view1 = CDSView(source=filter_topics_table_source, filters=[IndexFilter()]) filter_topics_table = DataTable(source=filter_topics_table_source, view=view1, columns=table_columns, width=500, selectable=True, scroll_to_selection=True, css_classes=['filter_topics_table']) filter_custom_table_source = ColumnDataSource(data={}) filter_custom_table = DataTable(source=filter_custom_table_source, columns=table_columns, width=500, selectable=True, css_classes=['filter_custom_table']) radio_group_area = RadioGroup(labels=[ "All", "Top Topics", "Trends", "Trend Clustering", "Custom Trends",
def file_handler(attr, old, new):
    """File-picker callback: add the picked image, then rebuild the image view.

    The replacement view keeps the indices of the current view's first
    filter. Presumably the new ``CDSView`` is assigned to force the renderer
    to pick up the added row — TODO confirm.
    """
    add_image(new, source, filepicker.filename)
    current = img_plot.view.filters[0].indices
    refreshed_view = CDSView(source=source, filters=[IndexFilter(current)])
    img_plot.view = refreshed_view
def plot_bokeh(self, plot_name=None, show_plot=False, barplot=True, chng=True):
    """
    Plot OOB-scores as bar- or linechart in bokeh.

    Only the ``MAX_PARAMS_TO_PLOT`` parameters with the highest importance are
    drawn, and only those appear as categories on the x-axis (in their
    original order).

    Parameters
    ----------
    plot_name: str
        path where to store the plot, None to not save it
    show_plot: bool
        whether or not to open plot in standard browser
    barplot: bool
        plot OOB as barchart
    chng: bool
        plot OOB as linechart

    Returns
    -------
    layout: bokeh.models.Row
        bokeh plot (can be used in notebook or embedded with components)
    """
    # Get all relevant data-points
    params = list(self.evaluated_parameter_importance.keys())
    p_names_short = shorten_unique(params)
    errors = list(self.evaluated_parameter_importance.values())
    max_to_plot = min(len(errors), self.MAX_PARAMS_TO_PLOT)
    # Positions of the `max_to_plot` largest values, largest first.
    plot_indices = sorted(range(len(errors)),
                          key=lambda x: errors[x],
                          reverse=True)[:max_to_plot]
    # x-axis categories: only the plotted parameters, in source order, so the
    # axis has no empty slots and matches the width computed from max_to_plot.
    x_factors = [p_names_short[i] for i in sorted(plot_indices)]

    # Customizing plot-style
    bar_width = 25

    # Create ColumnDataSource for both plots
    source = ColumnDataSource(data=dict(
        parameter_names_short=p_names_short,
        parameter_names=params,
        parameter_importance=errors,
    ))
    plots = []

    view = CDSView(source=source, filters=[IndexFilter(plot_indices)])
    tooltips = [
        ("Parameter", "@parameter_names"),
        ("Importance", "@parameter_importance"),
    ]

    def _new_figure():
        """Return an empty hover-enabled figure sized for `max_to_plot` categories."""
        return figure(x_range=x_factors,
                      plot_height=350,
                      plot_width=100 + max_to_plot * bar_width,
                      toolbar_location=None,
                      tools="hover",
                      tooltips=tooltips)

    if barplot:
        p = _new_figure()
        p.vbar(x='parameter_names_short',
               top='parameter_importance',
               width=0.9,
               source=source,
               view=view)
        # Dashed red lines at +/- the importance threshold.
        for value in [self.IMPORTANCE_THRESHOLD, -self.IMPORTANCE_THRESHOLD]:
            p.add_layout(
                Span(location=value,
                     dimension='width',
                     line_color='red',
                     line_dash='dashed'))
        plots.append(p)

    if chng:
        p = _new_figure()
        p.line(x='parameter_names_short',
               y='parameter_importance',
               source=source,
               view=view)
        plots.append(p)

    # Common styling:
    for p in plots:
        p.xaxis.major_label_orientation = 1.3
        p.xaxis.major_label_text_font_size = "14pt"
        p.yaxis.formatter = BasicTickFormatter(use_scientific=False)
        p.yaxis.axis_label = 'CV-RMSE' if self.cv else 'OOB'

    layout = Row(*plots)

    # Save and show...
    save_and_show(plot_name, show_plot, layout)

    return layout